Source code for deepmol.base.transformer

from abc import abstractmethod

from deepmol.base import Estimator
from deepmol.datasets import Dataset


class Transformer(Estimator):
    """
    Abstract base class for transformers.

    A transformer receives a Dataset and hands back a Dataset. Subclasses provide the actual logic by
    implementing `_transform`; callers go through `transform` or `fit_transform`.
    """

    def transform(self, dataset: Dataset) -> Dataset:
        """
        Apply the transformation to a dataset.

        The call is only allowed once `is_fitted` is truthy; the work itself is delegated to `_transform`.

        Parameters
        ----------
        dataset: Dataset
            The dataset to transform.

        Returns
        -------
        dataset: Dataset
            The value produced by `_transform` for the given dataset.

        Raises
        ------
        ValueError
            If the transformer has not been fitted yet.
        """
        if self.is_fitted:
            return self._transform(dataset)
        raise ValueError('Transformer needs to be fitted before calling transform()')

    @abstractmethod
    def _transform(self, dataset: Dataset) -> Dataset:
        """
        Subclass hook holding the actual transformation logic.

        Declared abstract here; `transform` delegates to it.

        Parameters
        ----------
        dataset: Dataset
            The dataset to transform.

        Returns
        -------
        dataset: Dataset
            The transformed dataset.
        """

    def fit_transform(self, dataset: Dataset) -> Dataset:
        """
        Fit on a dataset and transform that same dataset in one call.

        Calls `fit(dataset)` and then `transform(dataset)` on the object that `fit` returned.

        Parameters
        ----------
        dataset: Dataset
            The dataset to fit and transform.

        Returns
        -------
        dataset: Dataset
            The transformed dataset.
        """
        fitted = self.fit(dataset)
        return fitted.transform(dataset)
class PassThroughTransformer(Transformer):
    """
    Identity transformer: fitting changes nothing and transforming returns the input dataset as-is.
    """

    def _transform(self, dataset: Dataset) -> Dataset:
        """
        Return the dataset untouched.

        Parameters
        ----------
        dataset: Dataset
            The dataset to transform.

        Returns
        -------
        dataset: Dataset
            The very same object that was passed in.
        """
        return dataset

    def _fit(self, dataset: Dataset) -> 'PassThroughTransformer':
        """
        No-op fitting step; the dataset is not inspected.

        Parameters
        ----------
        dataset: Dataset
            The dataset to fit the transformer to (unused).

        Returns
        -------
        self: PassThroughTransformer
            This transformer instance.
        """
        return self
class DatasetTransformer(Transformer):
    """
    Transformer that wraps an arbitrary function and applies it to a dataset.
    """

    def __init__(self, func, **kwargs):
        """
        Store the function and the keyword arguments it will be called with.

        Parameters
        ----------
        func: callable
            The function to apply to the dataset; it is called as `func(dataset, **kwargs)`.
        kwargs: dict
            Extra keyword arguments forwarded to `func` on every call.
        """
        super().__init__()
        self.func = func
        self.kwargs = kwargs

    def _fit(self, dataset: Dataset) -> 'DatasetTransformer':
        """
        No-op fitting step; the dataset is not inspected.

        Parameters
        ----------
        dataset: Dataset
            The dataset to fit the transformer to (unused).

        Returns
        -------
        self: DatasetTransformer
            This transformer instance.
        """
        return self

    def _transform(self, dataset: Dataset) -> Dataset:
        """
        Call the wrapped function on the dataset.

        Parameters
        ----------
        dataset: Dataset
            The dataset to transform.

        Returns
        -------
        dataset: Dataset
            Whatever the wrapped function returns for `dataset` and the stored keyword arguments.
        """
        apply = self.func
        return apply(dataset, **self.kwargs)