midlearn.MIDExplainer
- class midlearn.MIDExplainer(estimator, target_classes: str | list[str] | None = None, params_main=None, params_inter=None, penalty=0, **kwargs)
Surrogate Maximum Interpretation Decomposition (MID) explainer.
- __init__(estimator, target_classes: str | list[str] | None = None, params_main=None, params_inter=None, penalty=0, **kwargs)
Create a surrogate MID model to explain a pre-trained black-box model.
- Parameters:
estimator (object) – The pre-trained black-box model to be explained.
target_classes (str or list of str, optional) – For classification estimators only. Specifies the target class or classes whose predicted probability is explained. If a list is provided, the sum of the listed classes' probabilities is used. If None (the default), the model explains 1 - P(class 0).
params_main (int, optional) – An integer specifying the maximum number of sample points for main effects. This corresponds to the ‘k[1]’ argument in R’s midr::interpret().
params_inter (int, optional) – An integer specifying the maximum number of sample points for interactions. This corresponds to the ‘k[2]’ argument in R’s midr::interpret().
penalty (float, optional) – The regularization penalty for pseudo-smoothing, corresponding to the ‘lambda’ argument in R’s midr::interpret(). Defaults to 0.
**kwargs (dict) – Additional keyword arguments to be passed directly to the underlying midr::interpret() function for advanced fitting options.
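The target_classes rule above can be illustrated with a small standalone sketch. It does not call midlearn; the helper name target_probability, the class labels, and the reading of "class 0" as the first entry of the class list are all assumptions made for illustration.

```python
def target_probability(proba_row, classes, target_classes=None):
    """Collapse one row of class probabilities into the scalar to explain.

    proba_row      : probabilities for one sample, ordered like `classes`
    classes        : class labels, in the same order as `proba_row`
    target_classes : a single label, a list of labels, or None
    """
    if target_classes is None:
        # Default described above: 1 - P(class 0). "Class 0" is assumed
        # here to mean the first entry of `classes`.
        return 1.0 - proba_row[0]
    if isinstance(target_classes, str):
        # A single label is treated as a one-element list.
        target_classes = [target_classes]
    # A list of labels means the sum of their probabilities.
    return sum(proba_row[classes.index(label)] for label in target_classes)


# Hypothetical three-class example.
classes = ["setosa", "versicolor", "virginica"]
row = [0.2, 0.5, 0.3]

default_target = target_probability(row, classes)
single_target = target_probability(row, classes, "virginica")
summed_target = target_probability(row, classes, ["versicolor", "virginica"])
print(default_target, single_target, summed_target)
```

In this example the default target and the two-class list give the same value, because the listed classes are exactly the classes other than the first one.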
Methods
__init__
(estimator[, target_classes, ...]) Create a surrogate MID model to explain a pre-trained black-box model.
breakdown
(**kwargs) Create MIDBreakdown object from the fitted estimator.
conditional
(variable, **kwargs) Create MIDConditional object from the fitted estimator.
effect
(term, x[, y]) Evaluate a single MID component function for new data.
fidelity_score
(X[, y, sample_weight]) Calculate the fidelity of the surrogate model.
fit
(X[, y, sample_weight]) Fit the surrogate MID model to the predictions of the estimator on X.
get_metadata_routing
() Get metadata routing of this object.
get_params
([deep]) Get parameters for this estimator.
importance
(**kwargs) Create MIDImportance object from the fitted estimator.
interactions
(term) Extract a pd.DataFrame representing the interaction of the specified 'term'.
main_effects
(term) Extract a pd.DataFrame representing the main effect of the specified 'term'.
plot
(term[, style, theme, intercept, ...]) Visualize the estimated main or interaction effect of a fitted MID model with plotnine.
predict
(X) Predict target values for new data X using the fitted MID model.
predict_terms
(X) Predict the contribution of each term for new data X.
r_predict
(X[, output_type, terms]) A low-level method to call the R predict.mid function.
score
(X, y[, sample_weight]) Return coefficient of determination on test data.
set_fit_request
(*[, sample_weight]) Configure whether metadata should be requested to be passed to the fit method.
set_params
(**params) Set the parameters of this estimator.
set_score_request
(*[, sample_weight]) Configure whether metadata should be requested to be passed to the score method.
terms
(**kwargs) Extract term labels from the fitted model.
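The relationship between per-term contributions (as in predict_terms) and a prediction (as in predict) can be sketched as an additive sum, assuming an identity link: the prediction is the intercept plus the contributions of all main-effect and interaction terms. The component functions, the term labels, and the helper names below are hypothetical and do not come from midlearn.

```python
# Hypothetical fitted components for a model with two features.
intercept = 10.0
main_effects = {
    "x1": lambda row: 2.0 * (row["x1"] - 1.0),
    "x2": lambda row: -0.5 * row["x2"],
}
interactions = {
    "x1:x2": lambda row: 0.1 * row["x1"] * row["x2"],
}


def term_contributions(row):
    """Return the contribution of every term for one observation."""
    terms = {**main_effects, **interactions}
    return {name: effect(row) for name, effect in terms.items()}


def additive_prediction(row):
    """Intercept plus the sum of all term contributions (identity link assumed)."""
    return intercept + sum(term_contributions(row).values())


row = {"x1": 3.0, "x2": 4.0}
contributions = term_contributions(row)
prediction = additive_prediction(row)
print(contributions)
print(prediction)
```

Keeping the contributions in a dictionary keyed by term label mirrors the idea of a per-term breakdown: each value can be inspected on its own, and together with the intercept they add up to the prediction.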
Attributes
fitted_matrix
A pandas DataFrame showing the breakdown of the fitted values into the effects of the component functions.
fitted_values
A NumPy array of the fitted values.
intercept
The intercept of the fitted model.
ratio
The ratio of the sum of squared errors between the target model predictions and the fitted values to the sum of squared deviations of the target model predictions from their mean.
residuals
A NumPy array of the working residuals.
weights
Sample weights used to fit the model.
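The definition of ratio above can be worked through on a small example. This is a minimal, unweighted sketch; the function name error_ratio and the numbers are hypothetical, and whether the library applies sample weights in this computation is not stated here.

```python
def error_ratio(target_predictions, fitted_values):
    """Sum of squared errors divided by sum of squared deviations.

    target_predictions : predictions of the black-box model
    fitted_values      : fitted values of the surrogate model
    Unweighted by assumption.
    """
    n = len(target_predictions)
    mean_target = sum(target_predictions) / n
    # Squared differences between black-box predictions and surrogate fit.
    sse = sum((t - f) ** 2 for t, f in zip(target_predictions, fitted_values))
    # Squared deviations of the black-box predictions from their mean.
    sst = sum((t - mean_target) ** 2 for t in target_predictions)
    return sse / sst


# Hypothetical values for four observations.
target = [1.0, 2.0, 3.0, 4.0]
fitted = [1.5, 2.0, 2.5, 4.0]

ratio = error_ratio(target, fitted)
print(ratio)
```

A ratio near 0 means the surrogate reproduces the black-box predictions closely. The quantity 1 - ratio has the familiar form of a coefficient of determination computed against the black-box predictions rather than the observed targets.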