Surrogate Modeling with MID in Python

Introduction

In modern actuarial science, there is an inherent tension between predictive accuracy and model transparency. While ensemble tree-based models like Gradient Boosting Machines (GBMs) frequently outperform traditional Generalized Linear Models (GLMs), their “black-box” nature presents significant hurdles for model governance, regulatory compliance, and price filing.

This notebook demonstrates a solution using Maximum Interpretation Decomposition (MID) via the {midlearn} library in Python.

Warning: Compatibility Notice

This article relies on features introduced in midlearn (>= 0.1.5) and the underlying R package midr (>= 0.6.0). Please ensure your library and package are up to date. Some arguments are not available in earlier versions.

What is MID?

MID is a functional decomposition method that deconstructs a black-box prediction function \(f(\mathbf{X})\) into interpretable components: an intercept \(g_\emptyset\), main effects \(g_j(X_j)\), and second-order interactions \(g_{jk}(X_j, X_k)\), while minimizing the expected squared residual \(\mathbf{E}\left[g_D(\mathbf{X})^2\right]\):

\[ f(\mathbf{X}) = g_\emptyset + \sum_{j} g_j(X_{j}) + \sum_{j < k} g_{jk}(X_{j},\;X_{k}) + g_D(\mathbf{X}) \]

To ensure the uniqueness and identifiability of each component, MID imposes centering and probability-weighted minimum-norm constraints on the decomposition.

By approximating a black-box model with this surrogate structure, we can derive a representation that retains the superior predictive power of machine learning models without sacrificing actuarial transparency. Furthermore, it allows us to quantify the “uninterpreted” variance, i.e., the portion of the model’s logic that can’t be captured by low-order effects, via the residual term \(g_D(\mathbf{X})\).
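
To make the decomposition concrete, the short sketch below applies a conditional-mean construction to a toy function with two independent inputs: the centered conditional means play the role of the main effects, and what remains after subtracting them is the interaction. This is only an illustration of the idea (it is not the estimation algorithm implemented in {midlearn}), and every function and number in it is invented.

Code
import numpy as np

# toy "black box": linear main effect in x1, nonlinear main effect in x2,
# plus a pure interaction term
def f(x1, x2):
  return 0.5 * x1 + np.sin(np.pi * x2) + 0.8 * x1 * x2

rng = np.random.default_rng(0)
x1 = rng.uniform(-1, 1, 200_000)
x2 = rng.uniform(-1, 1, 200_000)
y = f(x1, x2)

# intercept: overall mean of the predictions
g0 = y.mean()

# main effects: centered conditional means, estimated by binning each feature
edges = np.linspace(-1, 1, 21)
bin1 = np.digitize(x1, edges[1:-1])   # bin indices 0..19
bin2 = np.digitize(x2, edges[1:-1])
g1 = np.array([y[bin1 == b].mean() for b in range(20)]) - g0
g2 = np.array([y[bin2 == b].mean() for b in range(20)]) - g0

centers = (edges[:-1] + edges[1:]) / 2
print(np.round(g1[:3], 2), "vs", np.round(0.5 * centers[:3], 2))
print(np.round(g2[:3], 2), "vs", np.round(np.sin(np.pi * centers[:3]), 2))

# what remains after removing intercept and main effects is essentially the
# interaction 0.8 * x1 * x2 (plus estimation noise), i.e. the g_12 and g_D parts
resid = y - (g0 + g1[bin1] + g2[bin2])
print(round(float(np.corrcoef(resid, 0.8 * x1 * x2)[0, 1]), 3))   # close to 1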

Setting Up

We begin by setting up the environment and loading the necessary libraries.

Code
# utility
from pathlib import Path

# data manipulation
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# predictive modeling
import sklearn.linear_model as lm
import lightgbm as lgb

# import loss function
from sklearn.metrics import mean_poisson_deviance, r2_score

# surrogate modeling
import midlearn as mid

# visualization
from plotnine import *

# load training and testing datasets
PATH = Path("../data")
train = pd.read_parquet(PATH / "train.parquet")
test  = pd.read_parquet(PATH / "test.parquet")

In Python workflows, we typically partition the dataset into feature matrix \(X\), target vector \(y\), and weight vector \(w\).

Code
# training set
X_train = train.drop(['Frequency', 'Exposure'], axis=1)
y_train = train['Frequency']
w_train = train['Exposure']
# testing set
X_test  = test.drop(['Frequency', 'Exposure'], axis=1)
y_test  = test['Frequency']
w_test  = test['Exposure']

A key component of our evaluation is the Weighted Mean Poisson Deviance defined as follows:

\[ \text{Loss}(\mathbf{y}, \hat{\mathbf{y}}, \mathbf{w}) = \frac{\sum_{i=1}^n w_i\;d(y_i,\hat{y}_i)}{\sum_{i=1}^n w_i},\quad d(y_i,\hat{y}_i) = 2\left[y_i (\log y_i - \log \hat{y}_i) - (y_i - \hat{y}_i)\right] \]
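
For transparency, this loss can also be computed directly with NumPy and SciPy and checked against scikit-learn's mean_poisson_deviance; the small arrays below are made-up placeholder values used only for the comparison.

Code
import numpy as np
from scipy.special import xlogy
from sklearn.metrics import mean_poisson_deviance

def weighted_poisson_deviance(y, y_hat, w):
  # unit deviance d(y, y_hat); xlogy handles y == 0 correctly (contributes 0)
  d = 2 * (xlogy(y, y / y_hat) - (y - y_hat))
  # exposure-weighted average
  return np.average(d, weights=w)

# made-up placeholder values, only to check agreement with scikit-learn
y_obs = np.array([0.0, 1.0, 0.0, 2.0])
y_hat = np.array([0.1, 0.5, 0.2, 1.5])
w     = np.array([0.5, 1.0, 0.8, 1.0])

print(weighted_poisson_deviance(y_obs, y_hat, w))
print(mean_poisson_deviance(y_obs, y_hat, sample_weight=w))   # should match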

The Interpretable Baseline: GLM

We first fit a GLM to establish a transparent benchmark. Since GLMs are strictly additive on the link scale, they provide a “ground truth” structure. This allows us to verify whether the MID framework can accurately recover the original coefficients and linear effects before moving to more complex black-box models.

Code
# define variable encoder for PoissonRegressor
def one_hot_encode(X):
  cats = X.select_dtypes(include=['object', 'category']).columns.tolist()
  nums = X.select_dtypes(exclude=['object', 'category']).columns.tolist()
  ct = ColumnTransformer(
    transformers=[
      ('cat', OneHotEncoder(drop='first', sparse_output=False), cats),
      ('num', 'passthrough', nums)
    ],
    verbose_feature_names_out=False
  )
  ct.set_output(transform="pandas")
  return ct.fit_transform(X)

# initialize and train a Poisson GLM
fit_glm = lm.PoissonRegressor(alpha=0, max_iter=300)
fit_glm.fit(one_hot_encode(X_train), y_train, sample_weight=w_train)
PoissonRegressor(alpha=0, max_iter=300)
Code
# evaluate fitted model
y_pred_fit_glm = fit_glm.predict(one_hot_encode(X_test))

deviance = mean_poisson_deviance(
  y_true=y_test,
  y_pred=y_pred_fit_glm,
  sample_weight=w_test
)
print("Mean Poisson Deviance:", round(deviance, 6))
Mean Poisson Deviance: 0.470563

Surrogate Modeling

A surrogate model is an interpretable model trained to approximate the predictions of a black-box model \(f(\mathbf{X})\).

We apply the MIDExplainer() class to the fitted GLM to construct a first-order MID model, i.e., a GAM-like additive structure, as an interpretable surrogate of the original model.

Code
# build a surrogate model
mid_glm = mid.MIDExplainer(
  estimator=fit_glm,
  link="log"
)

mid_glm.fit(
  X_train,
  y=fit_glm.predict(one_hot_encode(X_train))
)
MIDExplainer(estimator=PoissonRegressor(alpha=0, max_iter=300), link='log')
Code
# evaluate fitted surrogate
print(
  "Uninterpreted Variation Ratio:",
  round(mid_glm.ratio['working'], 6)
)

y_pred_mid_glm = mid_glm.predict(X_test)

deviance = mean_poisson_deviance(
  y_true=y_test,
  y_pred=y_pred_mid_glm,
  sample_weight=w_test
)
print("Mean Poisson Deviance:", round(deviance, 6))
Uninterpreted Variation Ratio: 0.0
Mean Poisson Deviance: 0.470563

Model Fidelity

To assess model fidelity, i.e., how closely the surrogate replicates the black-box, we calculate the uninterpreted variation ratio \(\mathbf{U}\).

\[ \mathbf{U}(f,g) = \frac{\sum_{i=1}^n (f(x_i) - g(x_i))^2}{\sum_{i=1}^n (f(x_i) - \bar{f})^2}, \quad \text{where } \bar{f} = \frac{1}{n}\sum_{i=1}^n f(x_i) \]

This metric represents the proportion of the black-box model’s variance that is not captured by the additive components of the MID model. The R-squared score, \(\mathbf{R}^2(f,g) = 1 - \mathbf{U}(f,g)\), is a standard measure for this purpose. However, it is important to note that this \(\mathbf{R}^2\) measures the fidelity to the black-box model \(f\), rather than the predictive accuracy relative to the ground truth observations as is typically the case with standard R-squared.
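
The ratio is easy to compute by hand from two prediction vectors. A minimal sketch follows, where f_hat and g_hat are made-up link-scale predictions standing in for the black-box model and its surrogate.

Code
import numpy as np

def uninterpreted_ratio(f_pred, g_pred):
  # share of the black-box model's variance not captured by the surrogate
  f_pred, g_pred = np.asarray(f_pred, float), np.asarray(g_pred, float)
  rss = np.sum((f_pred - g_pred) ** 2)          # disagreement with the surrogate
  tss = np.sum((f_pred - f_pred.mean()) ** 2)   # variation of the black box
  return rss / tss

# made-up link-scale (log) predictions of a black box and its surrogate
f_hat = np.log(np.array([0.20, 0.50, 0.90, 1.40]))
g_hat = np.log(np.array([0.25, 0.45, 0.95, 1.30]))

u = uninterpreted_ratio(f_hat, g_hat)
print(round(u, 4), "fidelity R^2:", round(1 - u, 4))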

In the {midlearn} library, the summary output includes this ratio calculated on the training set. For models with non-linear links (e.g., Poisson regression), the “working” ratio is computed on the scale of the link function (e.g., \(\log\) scale).

To rigorously confirm the model fidelity, it is recommended to evaluate these metrics on a separate testing set. This ensures that the surrogate model is not just over-fitting the training predictions but has truly captured the underlying functional structure.

Code
# calculate R-squared on testing dataset
r2_mid = r2_score(
  y_true=np.log(fit_glm.predict(one_hot_encode(X_test))),
  y_pred=np.log(mid_glm.predict(X_test)),
  sample_weight=w_test
)
print("R squared:", round(r2_mid, 10))
R squared: 1.0

As shown by the high \(\mathbf{R}^2\) score, the MID surrogate achieves near-perfect fidelity. This level of agreement justifies using the MID components (main effects and interactions) as a reliable lens through which to interpret the original black-box model’s behavior.

Feature Effects

While the coefficients of a GLM are directly interpretable, visualizing their functional behavior across the feature space provides a more intuitive grasp of the model’s structure. In the Python ecosystem, standard libraries for GLMs often lack built-in “term plots” (common in R) to visualize partial effects on the link scale.

Here, the MID surrogate serves as a visualization tool: because it faithfully replicates the PoissonRegressor, we can plot the main effects \(g_j(X_j)\) directly and verify that the linear relationships on the log scale are correctly captured.

Code
# main effects of MID surrogate
plots = []
for feature in X_train.columns:
  p = (
  mid_glm.plot(feature) +
    lims(y=[-0.8, 0.8]) +
    labs(y="Main Effect")
  )
  plots.append(p)

(
  (plots[0] | plots[1] | plots[2] | plots[5]) /
  (plots[3] | plots[4] | plots[6])
).show()

Effect Importance

Beyond simple effect plots, {midlearn} provides a suite of diagnostic tools. First, the Effect Importance of a term \(j\) is defined as the mean absolute contribution of that term across the population:

\[ \text{Importance}_j = \mathbf{E} \left[ | g_j(X_j) | \right] \approx \frac{1}{n} \sum_{i=1}^n | g_j(x_{ij}) | \]
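
Before turning to the library's importance() method, the following sketch shows the calculation itself on a hand-typed table of per-observation term contributions; the feature names match this dataset, but the numbers are invented purely for illustration.

Code
import pandas as pd

# hypothetical per-observation term contributions g_j(x_ij) on the link scale;
# the values are made up purely to illustrate the calculation
contributions = pd.DataFrame({
  "DrivAge":            [ 0.31, -0.12,  0.05, -0.40],
  "LogDensity":         [ 0.10,  0.22, -0.18,  0.02],
  "DrivAge:LogDensity": [ 0.04, -0.01,  0.00,  0.07],
})

# importance = mean absolute contribution of each term
importance = contributions.abs().mean().sort_values(ascending=False)
print(importance)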

Code
imp_glm = mid_glm.importance(max_nsamples=2000)

p1 = (
  imp_glm.plot(fill="steelblue") +
  labs(title="Mean Absolute Effect", subtitle="Bar Plot")
)
p2 = (
  imp_glm.plot(style="sinaplot", theme="mako@div") +
  labs(subtitle="Distribution of Effects") +
  theme(legend_position="none", axis_text_y=element_blank())
)
(p1 | p2).show()

For interaction terms, the importance is similarly calculated using \(g_{jk}(X_j, X_k)\). This metric allows us to rank features by their average influence on the model’s predictions.

Conditional Expectation

Second, we can explore Individual Conditional Expectations (ICE). In the MID framework, the ICE for a feature \(j\) and a specific observation \(i\) is the expected value of the prediction as \(X_j\) varies, while keeping other features fixed at their observed values \(\mathbf{x}_{i,\setminus j}\):

\[ \text{ICE}_{i, j}(x) = g_\emptyset + g_j(x) + \sum_{k \neq j} g_{jk}(x, x_{ik}) \]
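
To see why the spread of ICE curves reflects interactions, the toy sketch below evaluates this formula for a hand-written surrogate with one main effect and one interaction term; all components and values are invented for illustration.

Code
import numpy as np

# hand-written surrogate components (all invented for illustration)
g0 = -2.0
def g_age(a):          # main effect of the varied feature
  return -0.02 * (a - 45)
def g_age_dens(a, d):  # interaction with a second, fixed feature
  return 0.005 * (a - 45) * d

grid = np.linspace(18, 80, 5)            # values of the varied feature
dens_obs = np.array([-2.0, 0.0, 3.0])    # other feature at its observed values

# ICE_{i,j}(x) = g0 + g_j(x) + g_{jk}(x, x_ik): one curve per observation i
ice = np.array([[g0 + g_age(a) + g_age_dens(a, d) for a in grid]
                for d in dens_obs])
print(np.round(ice, 2))

# the vertical spread between curves comes entirely from the interaction;
# with no interaction term every row would be the same curve
print(np.round(ice.max(axis=0) - ice.min(axis=0), 2))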

Code
ice_glm_link = mid_glm.conditional(
  type="link", variable="DrivAge", data=X_train.sample(200)
)
ice_glm = mid_glm.conditional(
  variable="DrivAge", data=X_train.sample(200)
)

p1 = (
  ice_glm_link.plot(theme="bluescale", var_color="LogDensity") +
  theme(legend_position="bottom") +
  labs(y="Linear Predictor",
       title="Conditional Expectation",
       subtitle="Change in linear predictor")
  )
p2 = (
  ice_glm.plot(style="centered", theme="bluescale",
               var_color="LogDensity") +
  theme(legend_position = "bottom") +
  labs(y="Prediction", title="",
       subtitle="Change in original scale")
  )
(p1 | p2).show()

Unlike standard black-box models, MID’s low-order structure allows us to compute these expectations efficiently and interpret the variation across curves (the “thickness” of the ICE plot) as a direct consequence of specified interaction terms \(g_{jk}\).

Additive Attribution

Third, we perform instance-level explanation through an Additive Breakdown of the prediction. For any single observation \(\mathbf{x}\), the MID surrogate’s prediction \(g(\mathbf{x})\) is decomposed into the exact sum of its functional components:

\[ g(\mathbf{x}) = \underbrace{g_\emptyset}_{\text{Intercept}} + \underbrace{\sum_{j} g_j(x_j)}_{\text{Main Effects}} + \underbrace{\sum_{j < k} g_{jk}(x_j, x_k)}_{\text{Interactions}} \]
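
The sketch below illustrates the idea with hand-typed, hypothetical link-scale contributions (not the output of mid_glm.breakdown): the terms sum exactly to the linear predictor, and ordering them by absolute size gives the ordering used in a waterfall plot.

Code
import numpy as np
import pandas as pd

# hypothetical link-scale contributions for a single policy (made-up values,
# only to show that the pieces add back up to the prediction)
contrib = pd.Series({
  "intercept":           -2.31,
  "DrivAge":              0.42,
  "Region":              -0.08,
  "LogDensity":           0.15,
  "DrivAge:LogDensity":   0.05,
})

eta = contrib.sum()                                  # link-scale prediction
print(round(eta, 2), round(float(np.exp(eta)), 4))   # and the implied frequency

# a waterfall orders the non-intercept terms by absolute contribution
print(contrib.drop("intercept").abs().sort_values(ascending=False))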

Code
np.random.seed(42)
row_ids = sorted(np.random.randint(0, train.shape[0], 4).tolist())

bd_plots = []
for idx in row_ids:
  bd = mid_glm.breakdown(row=idx)
  label = "Breakdown for Row " + str(idx)
  p = (
    bd.plot(theme="shap", format_args={'digits': 2}) +
    labs(x="", subtitle=label) +
    theme(legend_position="none")
  )
  bd_plots.append(p)

(
  (bd_plots[0] | bd_plots[1]) /
  (bd_plots[2] | bd_plots[3]) 
).show()

By visualizing these contributions in a waterfall plot, we can identify which specific risk factors or interaction effects drove the prediction for a particular instance, such as a high-risk policyholder.

The Black-Box: LightGBM

While GLMs are transparent, GBMs such as LightGBM often yield superior predictive power by capturing high-order interactions. However, this accuracy comes at the cost of being a black box.

Code
# data preprocessing for LightGBM
def str_to_cat(X):
  cats = X.select_dtypes(include=['object', 'category']).columns.tolist()
  X = X.copy()
  X[cats] = X[cats].astype('category')
  return X

# model parameters (mirroring the R version)
params_lgb = {
  'objective': "poisson",
  'n_estimators': 551,
  'learning_rate': 0.01672663973358928,
  'num_leaves': 61,
  'max_depth': 19,
  'min_child_samples': 10,
  'subsample': 0.8123000876841823,
  'colsample_bytree': 0.6507848978461632,
  'reg_alpha': 4.234603347091384,
  'reg_lambda': 8.790496879009705e-07,
  'random_state': 42,
  'n_jobs': -1,
  'verbosity': -1,
  'importance_type': 'gain'
}

# split datasets for validation
valid_n = int(X_train.shape[0] * 0.2)
train_X, valid_X = X_train.iloc[valid_n:], X_train.iloc[:valid_n]
train_y, valid_y = y_train[valid_n:], y_train[:valid_n]
train_w, valid_w = w_train[valid_n:], w_train[:valid_n]

# initialize and train a LightGBM
fit_lgb = lgb.LGBMRegressor(**params_lgb)

fit_lgb.fit(
  X=str_to_cat(train_X),
  y=train_y,
  sample_weight=train_w,
  eval_set=[(valid_X, valid_y)],
  eval_sample_weight=[valid_w],
  callbacks=[lgb.early_stopping(stopping_rounds=50)]
)
Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[458]   valid_0's poisson: 0.263843
LGBMRegressor(colsample_bytree=0.6507848978461632, importance_type='gain',
              learning_rate=0.01672663973358928, max_depth=19,
              min_child_samples=10, n_estimators=551, n_jobs=-1, num_leaves=61,
              objective='poisson', random_state=42, reg_alpha=4.234603347091384,
              reg_lambda=8.790496879009705e-07, subsample=0.8123000876841823,
              verbosity=-1)
Code
# evaluate fitted model
y_pred_fit_lgb = fit_lgb.predict(str_to_cat(X_test))

deviance = mean_poisson_deviance(
  y_true=y_test,
  y_pred=y_pred_fit_lgb,
  sample_weight=w_test
)
print("Mean Poisson Deviance:", round(deviance, 6))
Mean Poisson Deviance: 0.465148

Surrogate Modeling

We use {midlearn} to replicate the LightGBM model. By including second-order interaction terms in the surrogate (via interactions=True), we allow it to capture the joint relationships that the GBM has learned. The goal is to approximate the LightGBM function \(f_{LGB}(\mathbf{x})\) with our interpretable structure \(g(\mathbf{x})\):

\[ f_{LGB}(\mathbf{x}) \approx g(\mathbf{x}) = g_\emptyset + \sum_{j} g_j(x_j) + \sum_{j < k} g_{jk}(x_j, x_k) \]

Warning: Computational Considerations

Including all second-order interactions, whether via the (...)^2 formula syntax or interactions=True, results in \(p(p-1)/2\) interaction terms (21 pairwise terms for the seven features used here). For high-dimensional data this can be memory-intensive, so make sure sufficient RAM is available, limit the terms to the most relevant features, or fit the surrogate on a subset of the training set.

Code
# build a surrogate model
mid_lgb = mid.MIDExplainer(
  estimator=fit_lgb,
  interactions=True,
  link="log",
  penalty=0.01
)

mid_lgb.fit(
  X_train, 
  y=fit_lgb.predict(str_to_cat(X_train)),
  sample_weight=w_train
)
MIDExplainer(estimator=LGBMRegressor(colsample_bytree=0.6507848978461632,
                                     importance_type='gain',
                                     learning_rate=0.01672663973358928,
                                     max_depth=19, min_child_samples=10,
                                     n_estimators=551, n_jobs=-1, num_leaves=61,
                                     objective='poisson', random_state=42,
                                     reg_alpha=4.234603347091384,
                                     reg_lambda=8.790496879009705e-07,
                                     subsample=0.8123000876841823,
                                     verbosity=-1),
             link='log', penalty=0.01)
Code
# evaluate fitted surrogate
print(
  "Uninterpreted Variation Ratio:",
  round(mid_lgb.ratio['working'], 6)
)

y_pred_mid_lgb = mid_lgb.predict(X_test)

deviance = mean_poisson_deviance(
  y_true=y_test,
  y_pred=y_pred_mid_lgb,
  sample_weight=w_test
)
print("Mean Poisson Deviance:", round(deviance, 6))
Uninterpreted Variation Ratio: 0.093125
Mean Poisson Deviance: 0.467089

Model Fidelity

To measure how successfully our surrogate replicates the LightGBM model, we use the R-squared score on the link scale (\(\log\) scale).

Code
# calculate R-squared on testing dataset
r2_mid = r2_score(
  y_true=np.log(fit_lgb.predict(str_to_cat(X_test))),
  y_pred=np.log(mid_lgb.predict(X_test)),
  sample_weight=w_test
)
print("R squared:", round(r2_mid, 10))
R squared: 0.9051021397

Feature Effects

Code
# main effects of MID surrogate
plots = []
for feature in X_train.columns:
  p = (
  mid_lgb.plot(feature) +
    lims(y=[-1.0, 1.0]) +
    labs(y="Main Effect")
  )
  plots.append(p)

(
  (plots[0] | plots[1] | plots[2] | plots[5]) /
  (plots[3] | plots[4] | plots[6])
).show()

A key advantage of {midlearn} is its ability to isolate interaction effects \(g_{jk}\) from main effects \(g_j\). This is particularly useful for understanding the joint impact of two variables (e.g., Region and LogDensity).

Code
p1 = (
  mid_lgb.plot("LogDensity:Region", style="data",
               data=train.sample(2000), theme="bluescale") +
  labs(subtitle="Interaction Effect", y="") +
  theme(legend_position="bottom")
)
p2 = (
  mid_lgb.plot("LogDensity:Region", main_effects=True) +
  labs(subtitle="Total Effect", y="") +
  theme(axis_text_y=element_blank(),
        legend_position="bottom")
)
(p1 | p2).show()

Effect Importance

To rank the influence of each component discovered in the LightGBM model, we calculate the Effect Importance, defined as the average absolute contribution.

Code
imp_lgb = mid_lgb.importance(max_nsamples=2000)

p1 = (
  imp_lgb.plot(theme="bluescale@qual", max_nterms=20) +
  labs(title="Effect Importance", subtitle="Average absolute effect") +
  theme(legend_position="none")
)
p2 = (
  imp_lgb.plot(style="sinaplot", theme="mako@div", max_nterms=20) +
  labs(title="", subtitle="Distribution of effects") +
  theme(axis_text_y=element_blank(), legend_position="none")
)
(p1 | p2).show()

Conditional Expectation

We further explore the model’s behavior using the ICE plot. In the MID framework, the variation in ICE curves for a feature \(j\) is explicitly governed by the interaction terms \(g_{jk}\) identified from the LightGBM model.

Code
ice_lgb_link = mid_lgb.conditional(
  type="link", variable="DrivAge", data=X_train.sample(200)
)
ice_lgb = mid_lgb.conditional(
  variable="DrivAge", data=X_train.sample(200)
)

p1 = (
  ice_lgb_link.plot(theme="bluescale", var_color="LogDensity") +
  theme(legend_position="bottom") +
  labs(y="Linear Predictor",
       title="Conditional Expectation",
       subtitle="Change in linear predictor")
  )
p2 = (
  ice_lgb.plot(style="centered", theme="bluescale",
               var_color="LogDensity") +
  theme(legend_position = "bottom") +
  labs(y = "Prediction", title = "",
     subtitle = "Change in original scale")
  )
(p1 | p2).show()

Additive Attribution

Finally, we perform an Additive Breakdown for individual predictions. This provides an exact allocation of the LightGBM model's prediction into the terms of our surrogate model.

Code
np.random.seed(42)
row_ids = sorted(
  np.random.randint(0, train.shape[0], 4).tolist()
)
bd_plots = []
for idx in row_ids:
  bd = mid_lgb.breakdown(row=idx)
  label = "Breakdown for Row " + str(idx)
  p = (
    bd.plot(theme="shap", format_args={'digits': 4},
            max_nterms=10) +
    labs(x="", subtitle=label) +
    theme(legend_position="none")
  )
  bd_plots.append(p)

(
  (bd_plots[0] | bd_plots[1]) /
  (bd_plots[2] | bd_plots[3]) 
).show()

Conclusion

In this notebook, we have demonstrated how Maximum Interpretation Decomposition (MID) bridges the gap between predictive performance and model transparency. By using the {midlearn} library, we successfully transformed a complex LightGBM model into a structured, additive representation.

While the surrogate model fidelity may not always be perfect, the crucial advantage lies in our ability to quantify its limitations. Through the uninterpreted variation ratio, we can directly assess the complexity of the black-box model. If the fidelity is lower than expected, it serves as a diagnostic signal that the original model relies on high-order interactions or structural complexities that extend beyond second-order effects.

Knowing the extent of this “unexplained” variance is far more valuable than operating in the dark. It allows actuaries to make informed decisions about whether the additional complexity of a black-box model is justified by its performance, or if a more transparent structure is preferable for regulatory and risk management purposes.

As machine learning models become increasingly prevalent in insurance pricing and reserving, tools like MID will be essential for ensuring that our “black-boxes” remain accountable, reliable, and fundamentally understood.