Source code for deepmol.base.estimator

from abc import abstractmethod

from deepmol.base._serializer import Serializer
from deepmol.datasets import Dataset


class Estimator(Serializer):
    """
    Abstract base class for estimators.

    An estimator is any object that can be fitted to a Dataset object.
    Subclasses provide the actual learning logic in ``_fit``; the public
    ``fit`` method wraps it and records that fitting has completed.
    """

    def __init__(self, **kwargs):
        """
        Create the estimator in the unfitted state.

        Parameters
        ----------
        kwargs:
            Accepted so subclasses can forward arbitrary keyword arguments;
            they are not used by this base class.
        """
        self._is_fitted: bool = False

    def fit(self, dataset: Dataset) -> 'Estimator':
        """
        Fit the estimator to the given dataset.

        Delegates the work to ``_fit`` and marks the estimator as fitted
        once that call returns.

        Parameters
        ----------
        dataset: Dataset
            The dataset to fit the estimator to.

        Returns
        -------
        self: Estimator
            This estimator, now marked as fitted.
        """
        self._fit(dataset)
        # The flag is only set when _fit returned normally; if it raised,
        # this line is never reached.
        self._is_fitted = True
        return self

    @abstractmethod
    def _fit(self, dataset: Dataset) -> 'Estimator':
        """
        Subclass hook holding the actual fitting logic.

        Every concrete estimator must implement this method. Note that
        ``fit`` ignores its return value and returns ``self`` itself.

        Parameters
        ----------
        dataset: Dataset
            The dataset to fit the estimator to.

        Returns
        -------
        self: Estimator
            The fitted estimator.
        """

    def is_fitted(self) -> bool:
        """
        Report whether the estimator has been fitted.

        Falls back to ``False`` when the ``_is_fitted`` attribute was never
        set (e.g. a subclass that did not call this base ``__init__``).

        Returns
        -------
        is_fitted: bool
            The stored fitted flag, or ``False`` if it does not exist.
        """
        return getattr(self, '_is_fitted', False)