# Source code for deepmol.parallelism.multiprocessing

from abc import ABC, abstractmethod
from typing import Iterable

from joblib import Parallel, delayed

from deepmol.loggers.logger import Logger


class MultiprocessingClass(ABC):
    """
    Base class for multiprocessing.

    Subclasses implement :meth:`run`. :meth:`run_iteratively` is a sequential
    fallback that applies the process function in the current process, for
    cases where multiprocessing cannot be used.
    """

    def __init__(self, n_jobs: int = -1, process: callable = None):
        """
        Constructor for the MultiprocessingClass class.

        Parameters
        ----------
        n_jobs: int
            The number of jobs to use for multiprocessing. If -1, all available cores are used.
        process: callable
            The function to use for multiprocessing.
        """
        self.n_jobs = n_jobs
        self._process = process
        self.logger = Logger()

    @property
    def process(self):
        """
        Returns the function to use for multiprocessing.
        """
        return self._process

    def run_iteratively(self, items: Iterable):
        """
        Apply the process function to each item sequentially (no multiprocessing).

        Used as a fallback when multiprocessing is not possible, e.g. due to an
        error pickling the process function.

        Parameters
        ----------
        items: Iterable
            The items to process. If the first item is a tuple, every item is
            unpacked into positional arguments (``process(*item)``); otherwise
            each item is passed as a single argument (``process(item)``).

        Yields
        ------
        result
            The result of the process function for each item, in input order.
            Nothing is yielded for empty input.
        """
        # Materialize one-shot / non-indexable iterables (zip, generators, ...)
        # so the first element can be inspected without losing it.
        if not isinstance(items, (list, tuple)):
            items = list(items)
        # Empty input: nothing to do (items[0] below would raise IndexError).
        if not items:
            return
        if isinstance(items[0], tuple):
            for item in items:
                yield self.process(*item)
        else:
            for item in items:
                yield self.process(item)

    @abstractmethod
    def run(self, items: Iterable) -> Iterable:
        """
        Runs the multiprocessing.

        Parameters
        ----------
        items: Iterable
            The items to use for multiprocessing.

        Returns
        -------
        results: Iterable
            The results of the multiprocessing.
        """
class JoblibMultiprocessing(MultiprocessingClass):
    """
    Multiprocessing class using joblib.
    """

    def run(self, items: Iterable) -> Iterable:
        """
        Runs the multiprocessing.

        Applies the process function to every item using joblib's
        ``multiprocessing`` backend with ``self.n_jobs`` workers. If the first
        item is a tuple, every item is unpacked into positional arguments
        (``process(*item)``); otherwise each item is passed as a single argument.
        If parallel execution fails with an error whose message mentions
        pickling, a warning is logged and the items are processed sequentially
        in the current process instead.

        Parameters
        ----------
        items: Iterable
            The items to use for multiprocessing.

        Returns
        -------
        results: Iterable
            The results of the multiprocessing, as a list in input order.
            Empty input returns an empty list.

        Raises
        ------
        Exception
            Any error from parallel execution that is not pickling-related is
            re-raised unchanged.
        """
        # TODO: Add support for progress bar
        # Materialize one-shot / non-indexable iterables (zip, generators, map, ...)
        # so the first element can be inspected and the items can be re-used by
        # the sequential fallback.
        if not isinstance(items, (list, tuple)):
            items = list(items)
        # Empty input: nothing to process (items[0] below would raise IndexError).
        if not items:
            return []
        # If the first element is a tuple, every item is unpacked as *item.
        unpack = isinstance(items[0], tuple)
        try:
            parallel = Parallel(n_jobs=self.n_jobs, backend="multiprocessing")
            if unpack:
                results = parallel(delayed(self.process)(*item) for item in items)
            else:
                results = parallel(delayed(self.process)(item) for item in items)
        except Exception as e:
            if "pickle" not in str(e).lower():
                raise
            # functools.partial objects and callable instances have no __name__;
            # avoid raising AttributeError here and masking the original error.
            process_name = getattr(self.process, "__name__", repr(self.process))
            self.logger.warning(f"Failed to pickle process {process_name} function. Processing the input "
                                f"iteratively instead.")
            # Consume the fallback eagerly so both paths return a list and any
            # error from the process function surfaces inside run().
            results = list(self.run_iteratively(items))
        return results