midlearn.MIDExplainer

class midlearn.MIDExplainer(estimator, target_classes: str | list[str] | None = None, params_main=None, params_inter=None, penalty=0, **kwargs)

Surrogate Maximum Interpretation Decomposition (MID) explainer.

__init__(estimator, target_classes: str | list[str] | None = None, params_main=None, params_inter=None, penalty=0, **kwargs)

Create a surrogate MID model to explain a pre-trained black-box model.
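A surrogate is fitted to the black-box model's predictions rather than to the observed labels. The toy sketch below illustrates that idea without midlearn: it fits an intercept plus a piecewise-constant main effect for a single feature to the output of a stand-in model. It is a simplified stand-in for what midr::interpret() does (one feature, equal-width bins, no interactions, no penalty), and the names black_box and fit_binned_surrogate are hypothetical, not part of midlearn.

```python
from statistics import mean


def black_box(x):
    """Hypothetical pre-trained model: any callable that returns a prediction."""
    return x * x


def fit_binned_surrogate(xs, n_bins):
    """Toy one-feature surrogate: intercept + piecewise-constant main effect.

    The surrogate is fitted to black_box(x), not to observed labels. With
    constant bins and no penalty, least squares gives the per-bin mean of the
    predictions; subtracting the overall mean centers the main effect so that
    it averages to zero over xs.
    """
    preds = [black_box(x) for x in xs]  # target: model predictions, not labels
    lo, hi = min(xs), max(xs)
    width = (hi - lo) / n_bins

    def bin_of(x):
        # Clamp so that the maximum value falls into the last bin.
        return min(int((x - lo) / width), n_bins - 1)

    intercept = mean(preds)
    effect = {}
    for b in range(n_bins):
        in_bin = [p for x, p in zip(xs, preds) if bin_of(x) == b]
        if in_bin:  # skip empty bins
            effect[b] = mean(in_bin) - intercept

    def surrogate(x):
        return intercept + effect.get(bin_of(x), 0.0)

    return intercept, effect, surrogate


xs = [i / 10 for i in range(-10, 11)]  # 21 points in [-1, 1]
intercept, effect, surrogate = fit_binned_surrogate(xs, n_bins=4)
```

The real fit solves a joint least-squares problem over all main effects and interactions, so per-bin averaging only coincides with it in this one-feature, unpenalized case.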

Parameters:
  • estimator (object) – The pre-trained black-box model to be explained.

  • target_classes (str or list of str, optional) – For classification estimators only. Specifies the target class or classes whose predicted probability is explained. If a list is provided, the sum of the listed classes' probabilities is explained. If None (the default), the model explains 1 - P(class 0).

  • params_main (int, optional) – An integer specifying the maximum number of sample points for main effects. This corresponds to the ‘k[1]’ argument in R’s midr::interpret().

  • params_inter (int, optional) – An integer specifying the maximum number of sample points for interactions. This corresponds to the ‘k[2]’ argument in R’s midr::interpret().

  • penalty (float, optional) – The regularization penalty for pseudo-smoothing, corresponding to the ‘lambda’ argument in R’s midr::interpret(). Defaults to 0.

  • **kwargs (dict) – Additional keyword arguments to be passed directly to the underlying midr::interpret() function for advanced fitting options.
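For classifiers, the quantity being explained is a single number per sample derived from the class probabilities. The sketch below shows that mapping for one sample under the rules described for target_classes; the helper explained_probability is hypothetical, and it assumes that "class 0" means the first column of the estimator's predict_proba output.

```python
def explained_probability(proba, classes, target_classes=None):
    """Return the scalar probability explained for one sample (hypothetical helper).

    proba: class probabilities for one sample, in predict_proba column order.
    classes: class labels in the same order (e.g. estimator.classes_).
    target_classes: a single label, a list of labels, or None.
    """
    if target_classes is None:
        # Default: 1 - P(class 0), assuming "class 0" is the first column.
        return 1.0 - proba[0]
    if isinstance(target_classes, str):
        target_classes = [target_classes]  # treat one label as a one-element list
    position = {label: i for i, label in enumerate(classes)}
    # A list of target classes is explained through the sum of their probabilities.
    return sum(proba[position[label]] for label in target_classes)


classes = ["setosa", "versicolor", "virginica"]
proba = [0.2, 0.5, 0.3]
p_default = explained_probability(proba, classes)  # ≈ 0.8
p_single = explained_probability(proba, classes, "virginica")  # 0.3
p_multi = explained_probability(proba, classes, ["versicolor", "virginica"])  # ≈ 0.8
```

For a binary classifier the default and target_classes set to the second class coincide, since both reduce to the probability of the non-first class.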

Methods

__init__(estimator[, target_classes, ...])

Create a surrogate MID model to explain a pre-trained black-box model.

breakdown(**kwargs)

Create a MIDBreakdown object from the fitted estimator.

conditional(variable, **kwargs)

Create a MIDConditional object from the fitted estimator.

effect(term, x[, y])

Evaluate a single MID component function for new data.

fidelity_score(X[, y, sample_weight])

Calculate the fidelity of the surrogate model.

fit(X[, y, sample_weight])

Fit the surrogate MID model to the predictions of the estimator on X.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

importance(**kwargs)

Create a MIDImportance object from the fitted estimator.

interactions(term)

Extract a pd.DataFrame representing the interaction of the specified 'term'.

main_effects(term)

Extract a pd.DataFrame representing the main effect of the specified 'term'.

plot(term[, style, theme, intercept, ...])

Visualize the estimated main or interaction effect of a fitted MID model with plotnine.

predict(X)

Predict target values for new data X using the fitted MID model.

predict_terms(X)

Predict the contribution of each term for new data X.

r_predict(X[, output_type, terms])

A low-level method to call the R predict.mid function.

score(X, y[, sample_weight])

Return the coefficient of determination (R²) on test data.

set_fit_request(*[, sample_weight])

Configure whether metadata should be requested to be passed to the fit method.

set_params(**params)

Set the parameters of this estimator.

set_score_request(*[, sample_weight])

Configure whether metadata should be requested to be passed to the score method.

terms(**kwargs)

Extract term labels from the fitted model.
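A fitted MID model is additive: each prediction is the intercept plus the sum of the main-effect and interaction contributions, which predict(), predict_terms() and the intercept attribute expose from different angles. The pure-Python sketch below makes that relationship concrete with made-up contributions. It assumes the per-term output excludes the intercept, and that importance() follows R's midr::mid.importance(), which scores each term by its mean absolute contribution; both helper names and the "x1:x2" interaction label are illustrative assumptions.

```python
def reconstruct_predictions(intercept, term_contributions):
    """Rebuild predictions from per-term contributions (hypothetical helper).

    term_contributions: one dict per sample, mapping a term label to the
    value of that term's component function for the sample.
    """
    return [intercept + sum(row.values()) for row in term_contributions]


def mean_abs_importance(term_contributions):
    """Score each term by its mean absolute contribution, largest first."""
    n = len(term_contributions)
    terms = term_contributions[0].keys()
    scores = {t: sum(abs(row[t]) for row in term_contributions) / n for t in terms}
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


# Illustrative numbers only; in practice these would come from the fitted model.
intercept = 10.0
rows = [
    {"x1": 1.5, "x2": -0.5, "x1:x2": 0.25},
    {"x1": -2.0, "x2": 0.75, "x1:x2": -0.25},
]
predictions = reconstruct_predictions(intercept, rows)  # [11.25, 8.5]
importance = mean_abs_importance(rows)  # {"x1": 1.75, "x2": 0.625, "x1:x2": 0.25}
```

Because the decomposition is exact, a row of term contributions plus the intercept should reproduce the corresponding predict() value up to floating-point error.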

Attributes

fitted_matrix

A pandas DataFrame showing the breakdown of the fitted values into the effects of the component functions.

fitted_values

A NumPy array of the fitted values.

intercept

The intercept of the fitted model.

ratio

The ratio of the sum of squared errors between the target model's predictions and the fitted values to the sum of squared deviations of the target model's predictions from their mean.

residuals

A NumPy array of the working residuals.

weights

Sample weights used to fit the model.
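The ratio attribute is an unexplained-variance ratio in which the "targets" are the black-box predictions: ratio = Σ w_i (ŷ_i − f(x_i))² / Σ w_i (ŷ_i − ȳ)², where ŷ_i is the target model's prediction, f(x_i) the fitted value and ȳ the weighted mean of the ŷ_i. Consequently, 1 - ratio reads as an R²-style measure of how faithfully the surrogate reproduces the target model. The sketch below computes the residuals and the ratio from plain lists. It assumes an identity link, so that working residuals reduce to target prediction minus fitted value, and assumes the sample weights enter both sums; surrogate_diagnostics is a hypothetical helper.

```python
def surrogate_diagnostics(target_preds, fitted, weights=None):
    """Return (residuals, ratio) for a surrogate fit (hypothetical helper).

    target_preds: predictions of the black-box model on the training rows.
    fitted: the surrogate's fitted values on the same rows.
    weights: optional sample weights; equal weights are used if None.
    """
    if weights is None:
        weights = [1.0] * len(target_preds)
    # Identity-link working residuals: target prediction minus fitted value.
    residuals = [t - f for t, f in zip(target_preds, fitted)]
    total_weight = sum(weights)
    center = sum(w * t for w, t in zip(weights, target_preds)) / total_weight
    # Weighted sum of squared errors and weighted total sum of squares.
    sse = sum(w * r * r for w, r in zip(weights, residuals))
    sst = sum(w * (t - center) ** 2 for w, t in zip(weights, target_preds))
    return residuals, sse / sst


target_preds = [1.0, 2.0, 3.0, 4.0]
fitted = [1.5, 1.5, 3.5, 3.5]
residuals, ratio = surrogate_diagnostics(target_preds, fitted)
# residuals == [-0.5, 0.5, -0.5, 0.5]; ratio == 0.2, so 1 - ratio is about 0.8
```

Scaling all weights by the same constant leaves the ratio unchanged, so only the relative weights matter.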