Source code for deepmol.feature_importance.shap_values

import pandas as pd
import shap

from deepmol.datasets import Dataset
from deepmol.models.models import Model


[docs]class ShapValues: """ SHAP (SHapley Additive exPlanations) wrapper for DeepMol It allows to compute and analyze the SHAP values of DeepMol models. """ def __init__(self, dataset: Dataset, model: Model): """ Initialize the ShapValues object Parameters ---------- dataset: Dataset Dataset object model: Model Model object """ self.dataset = dataset self.model = model self.shap_values = None # TODO: masker not working
[docs] def computePermutationShap(self, masker: bool = False, plot: bool = True, max_evals: int = 500, **kwargs): """ Compute the SHAP values using the Permutation explainer. Parameters ---------- masker: bool If True, use a Partition masker to explain the model predictions on the given dataset plot: bool If True, plot the SHAP values max_evals: int Maximum number of iterations kwargs: dict Additional arguments for the plot function """ columns_names = self.dataset.feature_names X = pd.DataFrame(self.dataset.X, columns=columns_names, dtype=float) model = self.model.model if masker: y = self.dataset.y # build a clustering of the features based on shared information about y clustering = shap.utils.hclust(X, y) # above we implicitly used shap.maskers.Independent by passing a raw dataframe as the masker # now we explicitly use a Partition masker that uses the clustering we just computed masker = shap.maskers.Partition(X, clustering=clustering) # build a Permutation explainer and explain the model predictions on the given dataset explainer = shap.explainers.Permutation(model.predict_proba, masker) else: explainer = shap.explainers.Permutation(model.predict, X) self.shap_values = explainer(X, max_evals=max_evals) if plot: # visualize all the training set predictions if masker: shap.plots.bar(self.shap_values, **kwargs) else: shap.plots.beeswarm(self.shap_values, **kwargs)
# TODO: masker not working # TODO: too much iterations needed (remove?)
[docs] def computeExactShap(self, masker: bool = False, plot: bool = True, **kwargs): """ Compute the SHAP values using the Exact explainer. Parameters ---------- masker: bool If True, use a Partition masker to explain the model predictions on the given dataset plot: bool If True, plot the SHAP values kwargs: dict Additional arguments for the plot function """ columns_names = self.dataset.feature_names X = pd.DataFrame(self.dataset.X, columns=columns_names) model = self.model.model if masker: y = self.dataset.y # build a clustering of the features based on shared information about y clustering = shap.utils.hclust(X, y) # above we implicitly used shap.maskers.Independent by passing a raw dataframe as the masker # now we explicitly use a Partition masker that uses the clustering we just computed masker = shap.maskers.Partition(X, clustering=clustering) # build an Exact explainer and explain the model predictions on the given dataset explainer = shap.explainers.Exact(model.predict_proba, masker) else: explainer = shap.explainers.Exact(model.predict_proba, X) self.shap_values = explainer(X) if plot: # visualize all the training set predictions if masker: shap.plots.bar(self.shap_values, **kwargs) else: shap.plots.beeswarm(self.shap_values, **kwargs)
# TODO: check why force is not working (maybe java plugin is missing?)
[docs] def plotSampleExplanation(self, index: int = 0, plot_type: str = 'waterfall', **kwargs): """ Plot the SHAP values of a single sample. Parameters ---------- index: int Index of the sample to explain plot_type: str Type of plot to use. Can be 'waterfall' or 'force' kwargs: Additional arguments for the plot function. """ if self.shap_values is None: print('Shap values not computed yet! Computing shap values...') self.computeShap(plot=False) if plot_type == 'waterfall': # visualize the nth prediction's explanation shap.plots.waterfall(self.shap_values[index], **kwargs) elif plot_type == 'force': shap.initjs() # visualize the first prediction's explanation with a force plot shap.plots.force(self.shap_values[index], **kwargs) else: raise ValueError('Plot type must be waterfall or force!')
[docs] def plotFeatureExplanation(self, index: int = None, **kwargs): """ Plot the SHAP values of a single feature. Parameters ---------- index: int Index of the feature to explain kwargs: Additional arguments for the plot function. """ if index is None: # summarize the effects of all the features shap.plots.beeswarm(self.shap_values, **kwargs) else: # create a dependence scatter plot to show the effect of a single feature across the whole dataset shap.plots.scatter(self.shap_values[:, index], color=self.shap_values[:, index], **kwargs)
[docs] def plotHeatMap(self, **kwargs): """ Plot the SHAP values of all the features as a heatmap. Parameters ---------- kwargs: Additional arguments for the plot function. """ if self.shap_values is not None: shap.plots.heatmap(self.shap_values, **kwargs) else: raise ValueError('Shap values not computed yet!')
# TODO: check this again ''' def plotPositiveClass(self): shap_values2 = self.shap_values[...,1] print(shap_values2) shap.plots.bar(shap_values2) def plotNegativeClass(self): shap_values2 = self.shap_values[...,0] shap.plots.bar(shap_values2) '''