Source code for deepmol.scalers.base_scaler

from abc import ABC

import joblib

from deepmol.base import Transformer
from deepmol.datasets import Dataset
from deepmol.utils.decorators import modify_object_inplace_decorator


class BaseScaler(ABC, Transformer):
    """
    Abstract class for all scalers.
    It is used to define the interface for all scalers.

    The class wraps an underlying scaler object (anything exposing ``fit`` and
    ``transform``) and applies it to a selection of feature columns of a
    dataset's ``X`` matrix.
    """

    def __init__(self, scaler, columns: list = None) -> None:
        """
        Constructor for the BaseScaler class.

        Parameters
        ----------
        scaler: object
            The underlying scaler object; its ``fit`` and ``transform`` methods
            are called by this class.
        columns: list
            Indexes of the columns of ``dataset.X`` to scale. When ``None`` or
            empty, all columns are used.

        Raises
        ------
        Exception
            If BaseScaler itself (rather than a subclass) is instantiated.
        """
        if type(self) is BaseScaler:
            raise Exception('Abstract class BaseScaler should not be instantiated')
        super().__init__()
        self._scaler_object = scaler
        self.columns = columns

    @property
    def scaler_object(self):
        """
        Returns the scaler object.

        Returns
        -------
        object:
            The scaler object.
        """
        return self._scaler_object

    @scaler_object.setter
    def scaler_object(self, value: object):
        """
        Sets the scaler object.

        Parameters
        ----------
        value: object
            The scaler object.
        """
        self._scaler_object = value

    def save(self, file_path: str) -> None:
        """
        Saves the scaler object to a file.

        Only the wrapped scaler object is written; ``columns`` is not persisted.

        Parameters
        ----------
        file_path: str
            The path to the file where the scaler object will be saved.
        """
        joblib.dump(self._scaler_object, file_path)

    def load(self, file_path: str) -> 'BaseScaler':
        """
        Loads the scaler object from a file.

        Only the wrapped scaler object is restored; ``columns`` keeps whatever
        value this instance already had.

        Parameters
        ----------
        file_path: str
            The path to the file where the scaler object is saved.

        Returns
        -------
        BaseScaler
            This instance, with its scaler object replaced by the loaded one.
        """
        # NOTE: joblib.load is pickle-based and can execute arbitrary code;
        # only load files from trusted sources.
        self._scaler_object = joblib.load(file_path)
        return self

    @modify_object_inplace_decorator
    def scale(self, dataset: Dataset) -> Dataset:
        """
        Scales the dataset.

        Delegates to ``fit_transform``, i.e. the scaler is fitted on the given
        dataset and then applied to it.

        Parameters
        ----------
        dataset: Dataset
            The dataset to be scaled.
        """
        return self.fit_transform(dataset)

    def _resolve_columns(self, dataset: Dataset) -> list:
        """
        Returns the column indexes to operate on.

        Parameters
        ----------
        dataset: Dataset
            The dataset whose ``X`` matrix provides the number of columns.

        Returns
        -------
        list
            ``self.columns`` when it is set and non-empty, otherwise the
            indexes of every column of ``dataset.X``.
        """
        if self.columns:
            return self.columns
        return list(range(dataset.X.shape[1]))

    def _fit(self, dataset: Dataset) -> 'BaseScaler':
        """
        Fits the scaler with the dataset.

        When no columns were configured, all columns of ``dataset.X`` are
        selected and stored in ``self.columns`` for later transforms.

        Parameters
        ----------
        dataset: Dataset
            The dataset to be fitted.

        Returns
        -------
        BaseScaler
            The fitted scaler.
        """
        self.columns = self._resolve_columns(dataset)
        selected = dataset.X[:, self.columns]
        self._scaler_object.fit(selected)
        return self

    def _transform(self, dataset: Dataset) -> Dataset:
        """
        Transforms the dataset.

        The selected columns of ``dataset.X`` are overwritten in place with
        their scaled values.

        Parameters
        ----------
        dataset: Dataset
            The dataset to be transformed.

        Returns
        -------
        Dataset
            The same dataset object, with the selected columns scaled.
        """
        # Resolve the columns here as well: when `columns` was never given and
        # `_fit` has not run on this instance (e.g. the scaler object was
        # restored with `load`), `self.columns` is None and `X[:, None]` would
        # add an axis instead of selecting columns.
        columns = self._resolve_columns(dataset)
        selected = dataset.X[:, columns]
        scaled = self._scaler_object.transform(selected)
        # NOTE(review): the assignment below writes into the existing X array,
        # so results are presumably cast to X's dtype — confirm X is floating
        # point.
        dataset.X[:, columns] = scaled
        return dataset