Source code for midlearn.api

# src/midlearn/api.py

from __future__ import annotations
import warnings

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin, MetaEstimatorMixin, is_classifier
from sklearn.utils.validation import check_is_fitted
from sklearn.metrics import r2_score

from . import _r_interface

class MIDRegressor(BaseEstimator, RegressorMixin):
    """Stand-alone Maximum Interpretation Decomposition regressor.

    The model itself lives on the R side (package ``midr``); every method
    here delegates to a helper in ``_r_interface`` and stores the resulting
    R object in ``self.mid_``.
    """

    def __init__(
        self,
        params_main=None,
        params_inter=None,
        penalty=0,
        **kwargs
    ):
        """Create a MID model.

        Parameters
        ----------
        params_main : int, optional
            An integer specifying the maximum number of sample points for
            main effects. This corresponds to the 'k[1]' argument in R's
            `midr::interpret()`.
        params_inter : int, optional
            An integer specifying the maximum number of sample points for
            interactions. This corresponds to the 'k[2]' argument in R's
            `midr::interpret()`.
        penalty : float, optional
            The regularization penalty for pseudo-smoothing, corresponding
            to the 'lambda' argument in R's `midr::interpret()`.
            Defaults to 0.
        **kwargs : dict
            Additional keyword arguments to be passed directly to the
            underlying `midr::interpret()` function for advanced fitting
            options.
        """
        self.params_main = params_main
        self.params_inter = params_inter
        self.penalty = penalty
        self.kwargs = kwargs

    def fit(
        self,
        X,
        y,
        sample_weight=None
    ) -> MIDRegressor:
        """Fit the MID model to the response y on predictors X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data used to train the MID model. Anything that is not already
            a DataFrame is wrapped in one.
        y : array-like of shape (n_samples,)
            Target values.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        self : object
            The fitted estimator instance.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        # Remember the training layout so that r_predict() can re-order or
        # re-label new data to match it.
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = list(X.columns)
        self.mid_ = _r_interface._call_r_interpret(
            X=X,
            y=y,
            sample_weight=sample_weight,
            params_main=self.params_main,
            params_inter=self.params_inter,
            penalty=self.penalty,
            **self.kwargs
        )
        self.is_fitted_ = True
        return self

    def r_predict(
        self,
        X,
        output_type: str = 'response',
        terms: list[str] | None = None,
        **kwargs
    ) -> np.ndarray:
        """A low-level method to call the R predict.mid function.

        This method provides a direct interface to the R function,
        accepting common arguments explicitly and passing any others via
        keyword arguments.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray
            New data for which to make predictions.
        output_type : str, optional
            The type of prediction to return. Possible values are
            'response', 'terms', or 'link'. Defaults to 'response'.
        terms : list of str, optional
            A list of specific term names to get predictions. If None,
            predictions for all terms are returned. Defaults to None.
        **kwargs : dict
            Additional keyword arguments to be passed directly to the
            underlying `midr::predict.mid()` function for advanced options.

        Returns
        -------
        np.ndarray
            The prediction result from R.

        Raises
        ------
        ValueError
            If X is a DataFrame lacking one or more of the columns seen
            during fit.
        """
        check_is_fitted(self)
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=self.feature_names_in_)
        try:
            # Select (and order) exactly the columns used for fitting.
            X = X[self.feature_names_in_]
        except KeyError as e:
            missing_cols = set(self.feature_names_in_) - set(X.columns)
            raise ValueError(f"The following columns are missing: {list(missing_cols)}") from e
        res = _r_interface._call_r_predict(
            r_object=self.mid_,
            X=X,
            output_type=output_type,
            terms=terms,
            **kwargs
        )
        return np.asarray(res)

    def predict(
        self,
        X
    ) -> np.ndarray:
        """Predict target values for new data X using the fitted MID model.
        """
        # r_predict's parameter is named `output_type`; passing `type=` would
        # slip into **kwargs and be forwarded alongside the default
        # output_type instead of replacing it.
        return self.r_predict(X, output_type='response')

    def predict_terms(
        self,
        X
    ) -> np.ndarray:
        """Predict the contribution of each term for new data X.
        """
        # Must go through `output_type`; see the note in predict().
        return self.r_predict(X, output_type='terms')

    def effect(
        self,
        term: str,
        x: np.ndarray | pd.DataFrame,
        y: np.ndarray | None = None
    ) -> np.ndarray:
        """Evaluate a single MID component function for new data.

        Parameters
        ----------
        term : str
            The name of the model term to evaluate (e.g., 'x1', 'x1:x2').
        x : pd.DataFrame or np.ndarray
            New data for the first variable in the term. If a pd.DataFrame
            is provided, values of the related variables are extracted
            from it.
        y : np.ndarray, optional
            Values for the second variable in an interaction term. This is
            only required when evaluating a two-way interaction term.

        Returns
        -------
        np.ndarray
            A NumPy array of the calculated term contributions, with the
            same length as x.
        """
        check_is_fitted(self)
        res = _r_interface._call_r_mid_effect(
            r_object=self.mid_,
            term=term,
            x=x,
            y=y
        )
        return np.asarray(res)

    @property
    def intercept(self):
        """The intercept of the fitted model.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='intercept').item()

    @property
    def weights(self):
        """Sample weights used to fit the model.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='weights')

    @property
    def fitted_matrix(self):
        """A pandas DataFrame showing the breakdown of the fitted values into
        the effects of the component functions.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='fitted.matrix')

    @property
    def fitted_values(self):
        """A NumPy array of the fitted values.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='fitted.values')

    @property
    def residuals(self):
        """A NumPy array of the working residuals.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='residuals')

    @property
    def ratio(self):
        """The ratio of the sum of squared error between the target model
        predictions and the fitted values, to the sum of squared deviations
        of the target model predictions. Corresponds to 1 - R squared.
        """
        return _r_interface._extract_and_convert(r_object=self.mid_, name='ratio').item()

    def terms(self, **kwargs):
        """Extract term labels from the fitted model. See midr's mid.terms().
        """
        return list(_r_interface._call_r_mid_terms(r_object=self.mid_, **kwargs))

    def main_effects(self, term: str):
        """Extract a pd.DataFrame representing the main effect of the
        specified 'term'.
        """
        effects = _r_interface._extract_and_convert(r_object=self.mid_, name='main.effects')
        return _r_interface._extract_and_convert(r_object=effects, name=term)

    def interactions(self, term: str):
        """Extract a pd.DataFrame representing the interaction of the
        specified 'term'.
        """
        effects = _r_interface._extract_and_convert(r_object=self.mid_, name='interactions')
        return _r_interface._extract_and_convert(r_object=effects, name=term)

    def _encoding_type(self, tag: str, order: int = 1):
        # Walk encoders -> (main.effects | interactions) -> tag -> type and
        # return the first element of the stored 'type' entry.
        obj = _r_interface._extract_and_convert(r_object=self.mid_, name='encoders')
        obj = _r_interface._extract_and_convert(r_object=obj, name='main.effects' if order == 1 else 'interactions')
        obj = _r_interface._extract_and_convert(r_object=obj, name=tag)
        return _r_interface._extract_and_convert(r_object=obj, name='type')[0]

    def importance(self, **kwargs):
        """Create MIDImportance object from the fitted estimator.
        Refer to midr's mid.importance().
        """
        return MIDImportance(estimator=self, **kwargs)

    def breakdown(self, **kwargs):
        """Create MIDBreakdown object from the fitted estimator.
        Refer to midr's mid.breakdown().
        """
        return MIDBreakdown(estimator=self, **kwargs)

    def conditional(self, variable: str, **kwargs):
        """Create MIDConditional object from the fitted estimator.
        Refer to midr's mid.conditional().
        """
        return MIDConditional(estimator=self, variable=variable, **kwargs)
class MIDExplainer(MIDRegressor, MetaEstimatorMixin):
    """Surrogate Maximum Interpretation Decomposition explainer.

    Fits a MID model to the predictions of another (already trained)
    estimator so that the MID components can be used to explain it.
    """

    def __init__(
        self,
        estimator,
        target_classes: str | list[str] | None = None,
        params_main=None,
        params_inter=None,
        penalty=0,
        **kwargs
    ):
        """Create a surrogate MID model to explain a pre-trained black-box model.

        Parameters
        ----------
        estimator : object
            The pre-trained black-box model to be explained.
        target_classes : str or list of str, optional
            For classification estimators only. Specifies the target class
            or classes for which the probability is to be explained. If
            several classes are provided, the sum of their probabilities is
            used. If None (the default), the model explains
            1 - P(class 0).
        params_main : int, optional
            An integer specifying the maximum number of sample points for
            main effects. This corresponds to the 'k[1]' argument in R's
            `midr::interpret()`.
        params_inter : int, optional
            An integer specifying the maximum number of sample points for
            interactions. This corresponds to the 'k[2]' argument in R's
            `midr::interpret()`.
        penalty : float, optional
            The regularization penalty for pseudo-smoothing, corresponding
            to the 'lambda' argument in R's `midr::interpret()`.
            Defaults to 0.
        **kwargs : dict
            Additional keyword arguments to be passed directly to the
            underlying `midr::interpret()` function for advanced fitting
            options.
        """
        self.estimator = estimator
        self.target_classes = target_classes
        if not is_classifier(self.estimator) and self.target_classes is not None:
            warnings.warn(
                "'target_classes' is specified but will be ignored because the estimator is not a classifier.",
                UserWarning
            )
        super().__init__(
            params_main=params_main,
            params_inter=params_inter,
            penalty=penalty,
            **kwargs
        )

    def _predict_y_estimator(
        self,
        X
    ) -> np.ndarray:
        """Generate a unified continuous prediction from the estimator,
        handling both classifiers and regressors.

        For a classifier the result is a probability: the summed
        probability of ``target_classes`` when given, otherwise
        ``1 - P(first class)``. For any other estimator it is the output of
        ``predict``.

        Raises
        ------
        TypeError
            If the estimator lacks the method or attribute required for
            its kind (``predict_proba``, ``classes_`` or ``predict``).
        ValueError
            If some requested target class is not among the estimator's
            classes.
        """
        if not is_classifier(self.estimator):
            if not hasattr(self.estimator, "predict"):
                raise TypeError("The provided estimator must have a 'predict' method.")
            return self.estimator.predict(X)

        if not hasattr(self.estimator, "predict_proba"):
            raise TypeError("The provided estimator must have a 'predict_proba' method.")
        probas = self.estimator.predict_proba(X)

        if self.target_classes is None:
            return 1 - probas[:, 0]

        if not hasattr(self.estimator, 'classes_'):
            raise TypeError(
                "Estimator must have a 'classes_' attribute to use 'target_classes'."
            )
        class_labels = np.asarray(self.estimator.classes_)
        target_classes = self.target_classes
        # Accept any common container of labels, not only list: a tuple or
        # array wrapped as a single element would fail the count check below
        # even when every label is valid.
        if isinstance(target_classes, (list, tuple, set, frozenset, np.ndarray)):
            target_classes = list(target_classes)
        else:
            target_classes = [target_classes]
        target_indices = np.where(np.isin(class_labels, target_classes))[0]
        # Every distinct requested label must match exactly one class.
        if len(target_indices) != len(set(target_classes)):
            raise ValueError(
                "The 'target_classes' were not appropriately found in the estimator's classes."
            )
        return probas[:, target_indices].sum(axis=1)

    def fit(
        self,
        X,
        y=None,
        sample_weight=None
    ) -> MIDExplainer:
        """Fit the surrogate MID model to the predictions of the estimator on X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data used to train the MID model.
        y : array-like of shape (n_samples,), optional
            Predictions obtained from the original estimator. If None (the
            default), predictions are generated automatically from the
            estimator using X.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        self : object
            The fitted estimator (explainer) instance.
        """
        if y is None:
            print("Generating predictions from the estimator...")
            y = self._predict_y_estimator(X)
        super().fit(X=X, y=y, sample_weight=sample_weight)
        self.estimator_ = self.estimator
        return self

    def fidelity_score(
        self,
        X,
        y=None,
        sample_weight=None
    ) -> float:
        """Calculate the fidelity of the surrogate model.

        This score (R-squared) measures how well this explainer's
        predictions match the original estimator's predictions on the data
        X. A score close to 1.0 means the explainer is a faithful surrogate.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to evaluate fidelity on.
        y : array-like of shape (n_samples,), optional
            Predictions obtained from the original estimator. If None (the
            default), predictions are generated automatically from the
            estimator using X.
        sample_weight : array-like of shape (n_samples,), optional
            Sample weights to apply when calculating the score.
            Defaults to None.

        Returns
        -------
        score : float
            The R-squared score between the original estimator's
            predictions and this surrogate model's predictions.
        """
        check_is_fitted(self)
        if y is None:
            print("Generating predictions from the estimator...")
            y = self._predict_y_estimator(X)
        y_explainer = self.predict(X)
        return r2_score(y_true=y, y_pred=y_explainer, sample_weight=sample_weight)
class MIDImportance(object):
    """MID Importance.

    Returned by `MIDRegressor.importance()`; wraps the R object produced by
    the feature importance calculation and exposes its components.
    """

    def __init__(
        self,
        estimator: MIDRegressor | MIDExplainer,
        **kwargs
    ):
        """Initialize the MIDImportance object.

        Parameters
        ----------
        estimator : MIDRegressor or MIDExplainer
            The fitted MID model instance from which to calculate
            importance.
        **kwargs : dict
            Additional keyword arguments passed to the
            `midr::mid.importance()` function in R.
        """
        fitted_r_model = estimator.mid_
        self._obj = _r_interface._call_r_mid_importance(r_object=fitted_r_model, **kwargs)

    def _component(self, name):
        # Fetch one named element of the wrapped R object, converted for Python.
        return _r_interface._extract_and_convert(r_object=self._obj, name=name)

    @property
    def importance(self):
        """The 'importance' element of the underlying R object."""
        return self._component('importance')

    @property
    def predictions(self):
        """The 'predictions' element of the underlying R object."""
        return self._component('predictions')

    @property
    def measure(self):
        """The 'measure' element of the underlying R object."""
        return self._component('measure')

    def terms(self, **kwargs):
        """Return the term labels of the wrapped object as a list."""
        labels = _r_interface._call_r_mid_terms(r_object=self._obj, **kwargs)
        return list(labels)
class MIDBreakdown(object):
    """MID Breakdown.

    Returned by `MIDRegressor.breakdown()`; wraps the R object holding the
    detailed breakdown of a single prediction.
    """

    def __init__(
        self,
        estimator: MIDRegressor | MIDExplainer,
        row: int | None = None,
        **kwargs
    ):
        """Initialize the MIDBreakdown object.

        Parameters
        ----------
        estimator : MIDRegressor or MIDExplainer
            The fitted MID model instance to use for the breakdown.
        row : int, optional
            The specific row index (observation) in the data for which to
            create the breakdown. If None (the default), the breakdown for
            the first instance is calculated.
        **kwargs : dict
            Additional keyword arguments passed to the
            `midr::mid.breakdown()` function in R.
        """
        fitted_r_model = estimator.mid_
        self._obj = _r_interface._call_r_mid_breakdown(
            r_object=fitted_r_model,
            row=row,
            **kwargs
        )

    def _component(self, name):
        # Fetch one named element of the wrapped R object, converted for Python.
        return _r_interface._extract_and_convert(r_object=self._obj, name=name)

    @property
    def breakdown(self):
        """The 'breakdown' element of the underlying R object."""
        return self._component('breakdown')

    @property
    def data(self):
        """The 'data' element of the underlying R object."""
        return self._component('data')

    @property
    def intercept(self):
        """The 'intercept' element of the underlying R object."""
        return self._component('intercept')

    @property
    def prediction(self):
        """The 'prediction' element of the underlying R object."""
        return self._component('prediction')

    def terms(self, **kwargs):
        """Return the term labels of the wrapped object as a list."""
        labels = _r_interface._call_r_mid_terms(r_object=self._obj, **kwargs)
        return list(labels)
class MIDConditional(object):
    """MID Conditional Expectations.

    Returned by `MIDRegressor.conditional()`; wraps the R object that
    contains data for plotting conditional dependence.
    """

    def __init__(
        self,
        estimator: MIDRegressor | MIDExplainer,
        variable: str,
        **kwargs
    ):
        """Initialize the MIDConditional object.

        Parameters
        ----------
        estimator : MIDRegressor or MIDExplainer
            The fitted MID model instance to use.
        variable : str
            The name of the feature for which to calculate conditional
            dependence.
        **kwargs : dict
            Additional keyword arguments passed to the
            `midr::mid.conditional()` function in R.
        """
        fitted_r_model = estimator.mid_
        self._obj = _r_interface._call_r_mid_conditional(
            r_object=fitted_r_model,
            variable=variable,
            **kwargs
        )
        # Recorded only after the R call has succeeded.
        self.variable = variable

    def _component(self, name):
        # Fetch one named element of the wrapped R object, converted for Python.
        return _r_interface._extract_and_convert(r_object=self._obj, name=name)

    @property
    def observed(self):
        """The 'observed' element of the underlying R object."""
        return self._component('observed')

    @property
    def conditional(self):
        """The 'conditional' element of the underlying R object."""
        return self._component('conditional')

    @property
    def values(self):
        """The 'values' element of the underlying R object."""
        return self._component('values')

    def terms(self, **kwargs):
        """Return the term labels of the wrapped object as a list."""
        labels = _r_interface._call_r_mid_terms(r_object=self._obj, **kwargs)
        return list(labels)