from abc import ABC, abstractmethod
from typing import Callable, Iterable, Optional

from joblib import Parallel, delayed

from deepmol.loggers.logger import Logger
class MultiprocessingClass(ABC):
    """
    Base class for multiprocessing.

    Subclasses implement :meth:`run` to apply :attr:`process` to a collection of items.
    :meth:`run_iteratively` provides a sequential fallback that applies the same function
    in the current process.
    """

    def __init__(self, n_jobs: int = -1, process: Optional[Callable] = None):
        """
        Constructor for the MultiprocessingClass class.

        Parameters
        ----------
        n_jobs: int
            The number of jobs to use for multiprocessing. If -1, all available cores are used.
        process: Callable
            The function to use for multiprocessing.
        """
        self.n_jobs = n_jobs
        self._process = process
        self.logger = Logger()

    @property
    def process(self) -> Optional[Callable]:
        """
        Returns the function to use for multiprocessing.
        """
        return self._process

    def run_iteratively(self, items: Iterable):
        """
        Applies the process function to each item sequentially, without multiprocessing.

        Intended as a fallback when multiprocessing cannot be used (e.g. an error pickling the
        process function). Whether items are unpacked is decided once, from the first item.
        If it is a tuple, every item is called as ``process(*item)``; otherwise every item
        is called as ``process(item)``.

        Parameters
        ----------
        items: Iterable
            The items to process. Any iterable is accepted, including one-shot iterators such
            as generators or ``zip`` objects.

        Yields
        ------
        result
            The result of the process function for each item, lazily, in iteration order.
            Nothing is yielded when ``items`` is empty.
        """
        iterator = iter(items)
        missing = object()
        # Peek at the first item instead of indexing, so empty input and non-indexable
        # iterables (generators, map objects) are handled.
        first = next(iterator, missing)
        if first is missing:
            return
        unpack = isinstance(first, tuple)

        def call(item):
            # Tuples are treated as positional-argument packs, mirroring ``run``.
            return self.process(*item) if unpack else self.process(item)

        yield call(first)
        for item in iterator:
            yield call(item)

    @abstractmethod
    def run(self, items: Iterable) -> Iterable:
        """
        Runs the multiprocessing.

        Parameters
        ----------
        items: Iterable
            The items to use for multiprocessing.

        Returns
        -------
        results: Iterable
            The results of the multiprocessing.
        """
class JoblibMultiprocessing(MultiprocessingClass):
    """
    Multiprocessing class using joblib.
    """

    def run(self, items: Iterable) -> Iterable:
        """
        Runs the multiprocessing.

        The process function is applied to every item with joblib's ``multiprocessing``
        backend. If the first item is a tuple, every item is unpacked into positional
        arguments (``process(*item)``); otherwise each item is passed as a single argument.
        If the error raised while running mentions "pickle" (typically the process function
        cannot be pickled), the items are processed sequentially in the current process instead.

        Parameters
        ----------
        items: Iterable
            The items to use for multiprocessing. Any iterable is accepted (lists, tuples,
            zip objects, generators, ...); non-list/tuple iterables are materialized into a
            list first.

        Returns
        -------
        results: Iterable
            A list with one result per item. An empty list is returned when ``items`` is
            empty. The sequential fallback also returns a list, not a lazy generator.
        """
        # TODO: Add support for progress bar
        # Materialize one-shot iterables (zip, map, generators) so the first item can be
        # inspected and the sequential fallback can iterate them again after a failed run.
        if not isinstance(items, (list, tuple)):
            items = list(items)
        if not items:
            return []

        # If the first element is a tuple, each item is passed as positional arguments (*item).
        unpack = isinstance(items[0], tuple)
        delayed_process = delayed(self.process)
        tasks = (delayed_process(*item) if unpack else delayed_process(item) for item in items)
        try:
            return Parallel(n_jobs=self.n_jobs, backend="multiprocessing")(tasks)
        except Exception as e:
            # Anything unrelated to pickling is a genuine failure and must propagate unchanged.
            if "pickle" not in str(e):
                raise
            # functools.partial objects and callable instances have no __name__.
            name = getattr(self.process, "__name__", repr(self.process))
            self.logger.warning(f"Failed to pickle process {name} function. Processing the input "
                                f"iteratively instead.")
            # Consume eagerly so the return type matches the parallel path and any error
            # raised by the process function surfaces here rather than at the caller.
            return list(self.run_iteratively(items))