Surrogate Modeling with MID in R

Introduction

In modern actuarial science, there is an inherent tension between predictive accuracy and model transparency. While ensemble tree-based models like Gradient Boosting Machines (GBMs) frequently outperform traditional Generalized Linear Models (GLMs), their “black-box” nature presents significant hurdles for model governance, regulatory compliance, and price filing.

This notebook demonstrates a solution using Maximum Interpretation Decomposition (MID) via the {midr} and {midnight} packages in R.

Warning: Compatibility Notice

This article relies on features introduced in midr (>= 0.6.0) and midnight (>= 0.1.1.902). Please ensure your packages are up to date. Some visualization arguments are not available in earlier versions.

What is MID?

MID is a functional decomposition method that deconstructs a black-box prediction function \(f(\mathbf{X})\) into interpretable components: an intercept \(g_\emptyset\), main effects \(g_j(X_j)\), and second-order interactions \(g_{jk}(X_j, X_k)\), while minimizing the expected squared residual \(\mathbf{E}\left[g_D(\mathbf{X})^2\right]\):

\[ f(\mathbf{X}) = g_\emptyset + \sum_{j} g_j(X_{j}) + \sum_{j < k} g_{jk}(X_{j},\;X_{k}) + g_D(\mathbf{X}) \]

To ensure the uniqueness and identifiability of each component, MID imposes centering and probability-weighted minimum-norm constraints on the decomposition.

By approximating a black-box model with this surrogate structure, we can derive a representation that retains the superior predictive power of machine learning models without sacrificing actuarial transparency. Furthermore, it allows us to quantify the “uninterpreted” variance, i.e., the portion of the model’s logic that can’t be captured by low-order effects, via the residual term \(g_D(\mathbf{X})\).
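The decomposition above can be illustrated in a few lines of base R. The following toy example is purely illustrative and is not the midr estimation procedure (which handles arbitrary feature distributions, weights, and links): it decomposes a known function on a uniform grid by averaging, so the centering constraints hold by construction.

```r
# Toy functional decomposition of f(x1, x2) = x1 + x2 + x1*x2 on a grid,
# assuming a uniform distribution over grid points (for illustration only).
x1 <- seq(-1, 1, length.out = 5)
x2 <- seq(-1, 1, length.out = 5)
f  <- outer(x1, x2, function(a, b) a + b + a * b)

g0  <- mean(f)                       # intercept: grand mean
g1  <- rowMeans(f) - g0              # centered main effect of x1
g2  <- colMeans(f) - g0              # centered main effect of x2
g12 <- f - outer(g1, g2, `+`) - g0   # second-order interaction

# the components are centered and reconstruct f exactly
max(abs(g0 + outer(g1, g2, `+`) + g12 - f))
```

Here the centering constraints make each component unique; MID generalizes this idea beyond uniform grids to weighted data and non-identity links.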

Setting Up

We begin by setting up the environment and loading the necessary libraries.

Code
# data manipulation
library(arrow)
library(dplyr)

# predictive modeling
library(gam)
library(lightgbm)

# surrogate modeling
library(midr)
library(midnight)
nightfall(methods = TRUE, solvers = FALSE, themes = FALSE)

# visualization
library(ggplot2)
library(gridExtra)

# load training and testing datasets
train <- read_parquet("../data/train.parquet")
test  <- read_parquet("../data/test.parquet")

A key component of our evaluation is the Weighted Mean Poisson Deviance, defined as follows:

\[ \text{Loss}(\mathbf{y}, \hat{\mathbf{y}}, \mathbf{w}) = \frac{\sum_{i=1}^n w_i\;d(y_i,\hat{y}_i)}{\sum_{i=1}^n w_i},\quad d(y_i,\hat{y}_i) = 2\left[y_i (\log y_i - \log \hat{y}_i) - (y_i - \hat{y}_i)\right] \]

Code
# define loss function
mean_poisson_deviance <- function(
    y_true, y_pred, sample_weight = rep(1, length(y_true))
  ) {
  stopifnot(all(y_pred > 0))
  resid <- ifelse(y_true > 0, y_true * log(y_true / y_pred), 0)
  resid <- resid - y_true + y_pred
  2 * sum(resid * sample_weight) / sum(sample_weight)
}

The Interpretable Baseline: GAM

We first fit a GAM to establish a transparent benchmark. Since GAMs are additive by design, they provide a “ground truth” model structure to be recovered by the functional decomposition.

Code
fit_gam <- gam(
  Frequency ~ s(VehPower) + s(VehAge) + s(DrivAge) + s(LogDensity) +
              VehBrand + VehGas + Region,
  data = train,
  weights = Exposure,
  family = quasipoisson(link = "log")
)

summary(fit_gam)

Call: gam(formula = Frequency ~ s(VehPower) + s(VehAge) + s(DrivAge) + 
    s(LogDensity) + VehBrand + VehGas + Region, family = quasipoisson(link = "log"), 
    data = train, weights = Exposure)
Deviance Residuals:
    Min      1Q  Median      3Q     Max 
-0.7606 -0.3400 -0.2556 -0.1401 10.7719 

(Dispersion Parameter for quasipoisson family taken to be 1.7241)

    Null Deviance: 86471.6 on 338994 degrees of freedom
Residual Deviance: 84868.97 on 338966 degrees of freedom
AIC: NA 

Number of Local Scoring Iterations: NA 

Anova for Parametric Effects
                  Df Sum Sq Mean Sq  F value    Pr(>F)    
s(VehPower)        1     45   45.32  26.2882 2.942e-07 ***
s(VehAge)          1     35   34.96  20.2764 6.704e-06 ***
s(DrivAge)         1    320  319.82 185.4987 < 2.2e-16 ***
s(LogDensity)      1    531  531.36 308.1989 < 2.2e-16 ***
VehBrand           5     91   18.14  10.5199 4.072e-10 ***
VehGas             1     56   55.72  32.3196 1.309e-08 ***
Region             6     73   12.15   7.0499 1.606e-07 ***
Residuals     338966 584406    1.72                       
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Anova for Nonparametric Effects
              Npar Df Npar F     Pr(F)    
(Intercept)                               
s(VehPower)         3  1.320    0.2659    
s(VehAge)           3 11.223 2.328e-07 ***
s(DrivAge)          3 83.378 < 2.2e-16 ***
s(LogDensity)       3  1.052    0.3684    
VehBrand                                  
VehGas                                    
Region                                    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Code
# evaluate fitted model
pred_fit_gam <- predict(fit_gam, test, type = "response")

deviance <- mean_poisson_deviance(
  y_true = test$Frequency,
  y_pred = pred_fit_gam,
  sample_weight = test$Exposure
)
cat("Mean Poisson Deviance:", deviance)
Mean Poisson Deviance: 0.4679338

Surrogate Modeling

A surrogate model is an interpretable model trained to approximate the predictions of a black-box model \(f(\mathbf{X})\).

We apply the interpret() function to the GAM to construct a first-order MID model as an interpretable surrogate of the original model.

Code
mid_gam <- interpret(
  Frequency ~ VehPower + VehAge + DrivAge + LogDensity +
              VehBrand + VehGas + Region,
  data = train,
  weights = Exposure,
  link = "log",
  model = fit_gam
)

summary(mid_gam)

Call:
interpret(formula = Frequency ~ VehPower + VehAge + DrivAge +
 LogDensity + VehBrand + VehGas + Region, data = train, model = fit_gam,
 weights = Exposure, link = "log")

Link: log

Uninterpreted Variation Ratio:
   working   response 
3.7625e-05 2.5412e-05 

Working Residuals:
        Min          1Q      Median          3Q         Max 
-0.01317382 -0.00069381 -0.00004593  0.00057309  0.01821350 

Encoding:
           main.effect
VehPower     linear(9)
VehAge      linear(18)
DrivAge     linear(25)
LogDensity  linear(25)
VehBrand     factor(6)
VehGas       factor(2)
Region       factor(7)
Code
# evaluate fitted surrogate
pred_mid_gam <- predict(mid_gam, test, type = "response")

deviance <- mean_poisson_deviance(
  y_true = test$Frequency,
  y_pred = pred_mid_gam,
  sample_weight = test$Exposure
)
cat("Mean Poisson Deviance:", deviance)
Mean Poisson Deviance: 0.4679292

Model Fidelity

To assess model fidelity, i.e., how closely the surrogate replicates the black-box, we calculate the uninterpreted variation ratio \(\mathbf{U}\).

\[ \mathbf{U}(f,g) = \frac{\sum_{i=1}^n (f(x_i) - g(x_i))^2}{\sum_{i=1}^n (f(x_i) - \bar{f})^2}, \quad \text{where } \bar{f} = \frac{1}{n}\sum_{i=1}^n f(x_i) \]

This metric represents the proportion of the black-box model’s variance that is not captured by the additive components of the MID model. The R-squared score, \(\mathbf{R}^2(f,g) = 1 - \mathbf{U}(f,g)\), is a standard measure for this purpose. However, it is important to note that this \(\mathbf{R}^2\) measures the fidelity to the black-box model \(f\), rather than the predictive accuracy relative to the ground truth observations as is typically the case with standard R-squared.

In the {midr} package, the summary output includes this ratio calculated on the training set. For models with non-linear links (e.g., Poisson regression), the “working” ratio is computed on the scale of the link function (e.g., \(\log\) scale).

To rigorously confirm model fidelity, it is recommended to evaluate these metrics on a separate testing set. This ensures that the surrogate is not merely overfitting the training predictions but has genuinely captured the underlying functional structure.
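For reference, the ratio and its complement can be computed directly in base R. This is a minimal sketch of the definition above with an added weight argument; in practice we rely on midr's own utilities.

```r
# Weighted uninterpreted variation ratio U(f, g); R-squared is 1 - U.
# Minimal sketch of the definition above (base R only, hypothetical helper).
uninterpreted_ratio <- function(f_pred, g_pred, w = rep(1, length(f_pred))) {
  f_bar <- weighted.mean(f_pred, w)
  sum(w * (f_pred - g_pred)^2) / sum(w * (f_pred - f_bar)^2)
}
```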

Code
# calculate R-squared on testing dataset
R2_mid <- weighted.loss(
  x = log(pred_fit_gam),
  y = log(pred_mid_gam),
  w = test$Exposure,
  method = "r2"
)

cat(sprintf("R-squared: %.6f", R2_mid))
R-squared: 0.999963

As shown by the high \(\mathbf{R}^2\) score, the MID surrogate achieves near-perfect fidelity. This level of agreement justifies using the MID components (main effects and interactions) as a reliable lens through which to interpret the original black-box model’s behavior.

Feature Effects

Visualizing the functional behavior of each component allows for a direct comparison between the MID surrogate’s decomposition and the original GAM’s structure.

  • MID Surrogate
Code
# main effects of MID surrogate
par.midr(mfrow = c(2, 4))
mid.plots(mid_gam, engine = "graphics", ylab = "Main Effect")

  • Original GAM
Code
# feature effects of GAM
par.midr(mfrow = c(2, 4))
termplot(fit_gam)

Furthermore, we can visualize the joint effects of feature pairs as 3D prediction surfaces using the S3 method for the persp() function.

Code
par.midr(mar = c(1, 0, 1, 0), mfrow = c(1, 2))
persp(mid_gam, "DrivAge:LogDensity", theta = 45, phi = 40)
persp(mid_gam, "LogDensity:Region", theta = -45, phi = 40)

Effect Importance

Beyond simple plots for feature effects, {midr} provides a suite of diagnostic tools. First, the Effect Importance of a term \(j\) is defined as the mean absolute contribution of that term across the population:

\[ \text{Importance}_j = \mathbf{E} \left[ | g_j(X_j) | \right] \approx \frac{1}{n} \sum_{i=1}^n | g_j(x_{ij}) | \]

Code
imp_gam <- mid.importance(mid_gam, data = train, max.nsamples = 2000)
grid.arrange(
  nrow = 1, widths = c(5, 4),
  ggmid(imp_gam, fill = "steelblue") +
    labs(title = "Effect Importance",
         subtitle = "Average absolute effect per feature"),
  ggmid(imp_gam, type = "beeswarm", theme = "mako@div") +
    labs(title = "",
         subtitle = "Distribution of effect per feature") +
    scale_y_discrete(labels = NULL) +
    theme(legend.position = "none")
)

For interaction terms, the importance is similarly calculated using \(g_{jk}(X_j, X_k)\). This metric allows us to rank features by their average influence on the model’s predictions.
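Given a matrix of per-observation term contributions \(g_j(x_{ij})\), the definition reduces to a column-wise mean of absolute values. The values below are hypothetical toy contributions, not output from the fitted model:

```r
# Hypothetical per-observation contributions for two terms
contrib <- cbind(
  DrivAge    = c(-0.30, 0.10, 0.25, -0.05),
  LogDensity = c( 0.05, -0.02, 0.01, 0.04)
)
importance <- colMeans(abs(contrib))  # mean absolute contribution per term
sort(importance, decreasing = TRUE)
```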

Conditional Expectation

Second, we can explore Individual Conditional Expectations (ICE). In the MID framework, the ICE for a feature \(j\) and a specific observation \(i\) is the expected value of the prediction as \(X_j\) varies, while keeping other features fixed at their observed values \(\mathbf{x}_{i,\setminus j}\):

\[ \text{ICE}_{i, j}(x) = g_\emptyset + g_j(x) + \sum_{k \neq j} g_{jk}(x, x_{ik}) \]
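Under this structure, a single ICE curve is just a sum of evaluated terms. The component functions below are made up for illustration (not fitted values); varying the fixed feature changes the curve through the interaction term:

```r
# Hypothetical MID components: one varied feature x, one fixed feature z
g0   <- 0.1
g_j  <- function(x) 0.2 * x          # main effect of the varied feature
g_jk <- function(x, z) 0.05 * x * z  # interaction with the fixed feature
ice  <- function(x_grid, z_fixed) g0 + g_j(x_grid) + g_jk(x_grid, z_fixed)

grid <- seq(0, 1, by = 0.25)
ice(grid, z_fixed = 2)  # one curve; a larger z_fixed steepens the slope
```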

Code
ice_gam_link <- mid.conditional(mid_gam, type = "link", variable = "DrivAge")
ice_gam <- mid.conditional(mid_gam, variable = "DrivAge")
grid.arrange(
  nrow = 1,
  ggmid(ice_gam_link, var.color = LogDensity) +
    theme(legend.position = "bottom") +
    labs(y = "Linear Predictor",
         title = "Conditional Expectation",
         subtitle = "Change in linear predictor"),
  ggmid(ice_gam, type = "centered", var.color = LogDensity) +
    theme(legend.position = "bottom") +
    labs(y = "Prediction", title = "",
         subtitle = "Centered change in original scale")
)

Unlike standard black-box models, MID’s low-order structure allows us to compute these expectations efficiently and interpret the variation across curves (the “thickness” of the ICE plot) as a direct consequence of specified interaction terms \(g_{jk}\).

Additive Attribution

Third, we perform instance-level explanation through an Additive Breakdown of the prediction. For any single observation \(\mathbf{x}\), the MID surrogate’s prediction \(g(\mathbf{x})\) is decomposed into the exact sum of its functional components:

\[ g(\mathbf{x}) = \underbrace{g_\emptyset}_{\text{Intercept}} + \underbrace{\sum_{j} g_j(x_j)}_{\text{Main Effects}} + \underbrace{\sum_{j < k} g_{jk}(x_j, x_k)}_{\text{Interactions}} \]
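Because the sum is exact, a breakdown can be verified by hand. The numbers below are hypothetical, but they illustrate how contributions on the link scale add up to the linear predictor:

```r
# Hypothetical contributions on the log (link) scale for one policy
intercept   <- -2.10
main_fx     <- c(DrivAge = 0.35, LogDensity = 0.22, VehAge = -0.08)
interaction <- c("DrivAge:LogDensity" = 0.04)

linear_pred <- intercept + sum(main_fx) + sum(interaction)
linear_pred       # -1.57
exp(linear_pred)  # predicted frequency on the response scale
```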

Code
set.seed(42)
row_ids <- sort(sample(nrow(train), 4))
bd_list <- lapply(
  row_ids,
  function(x) {
    res <- mid.breakdown(mid_gam, train, row = x)
    structure(res, row_id = x)
  }
)
bd_plots <- lapply(
  bd_list, function(x) {
    label <- paste0("Breakdown for Row ", attr(x, "row_id"))
    ggmid(x, theme = "shap", format.args = list(digits = 4)) +
      labs(x = NULL, subtitle = label) +
      theme(legend.position = "none")
  }
)
grid.arrange(grobs = bd_plots)

By visualizing these contributions in a waterfall plot, we can identify which specific risk factors or interaction effects drove the prediction for a particular instance, such as a high-risk policyholder.

The Black-Box: LightGBM

While GAMs are transparent, GBMs such as LightGBM often yield superior predictive power by capturing high-order interactions. However, this accuracy comes at the cost of being a black box.

Code
# hold out validation dataset
valid_idx <- seq_len(floor(nrow(train) * 0.2))

# create datasets for training
dtrain <- lgb.Dataset(
  data.matrix(select(train[-valid_idx, ], -Frequency, -Exposure)),
  label = train$Frequency[-valid_idx],
  weight = train$Exposure[-valid_idx],
  categorical_feature = c("VehBrand", "VehGas", "Region")
)

dvalid <- lgb.Dataset.create.valid(
  dtrain,
  data.matrix(select(train[ valid_idx, ], -Frequency, -Exposure)),
  label = train$Frequency[ valid_idx],
  weight = train$Exposure[ valid_idx]
)

# model parameters
params_lgb <- list(
  objective = "poisson",
  learning_rate = 0.03188002,
  num_leaves = 30,
  reg_lambda = 0.004201069,
  reg_alpha = 0.2523909,
  colsample_bynode = 0.5552524,
  subsample = 0.5938199,
  min_child_samples = 9,
  min_split_gain = 0.3920509,
  poisson_max_delta_step = 0.8039541
)

set.seed(42)
fit_lgb <- lgb.train(
  params = params_lgb,
  data = dtrain,
  nrounds = 1000L,
  valids = list(eval = dvalid),
  early_stopping_round = 50L,
  verbose = -1L
)

summary(fit_lgb)
LightGBM Model (402 trees)
Objective: poisson
Fitted to dataset with 7 columns
Code
# evaluate fitted model
pred_fit_lgb <- predict(
  fit_lgb, data.matrix(select(test, -Frequency, -Exposure))
)

deviance <- mean_poisson_deviance(
  y_true = test$Frequency,
  y_pred = pred_fit_lgb,
  sample_weight = test$Exposure
)
cat("Mean Poisson Deviance:", deviance)
Mean Poisson Deviance: 0.4655023

Surrogate Modeling

We use {midr} to replicate the LightGBM model. By including interaction terms in the model formula, we allow the surrogate to capture the joint relationships that the GBM has learned. The goal is to approximate the LightGBM function \(f_{LGB}(\mathbf{x})\) with our interpretable structure \(g(\mathbf{x})\):

\[ f_{LGB}(\mathbf{x}) \approx g(\mathbf{x}) = g_\emptyset + \sum_{j} g_j(x_j) + \sum_{j < k} g_{jk}(x_j, x_k) \]

Warning: Computational Considerations

Including all second-order interactions via the (...)^2 syntax produces \(p(p-1)/2\) interaction terms. For high-dimensional data, this can be memory-intensive: ensure sufficient RAM is available, limit the formula to the most relevant features, or fit the surrogate on a subset of the training set.
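For the seven rating factors used in this notebook, the expansion is still modest:

```r
choose(7, 2)  # 21 candidate second-order interaction terms
```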

Code
mid_lgb <- interpret(
  Frequency ~ (VehPower + VehAge + DrivAge + LogDensity +
               VehBrand + VehGas + Region)^2,
  data = train,
  lambda = 0.01,
  weights = Exposure,
  link = "log",
  model = fit_lgb,
  pred.fun = function(model, data) {
    newdata <- data.matrix(select(data, -Frequency, -Exposure))
    predict(model, newdata)
  }
)

summary(mid_lgb)

Call:
interpret(formula = Frequency ~ (VehPower + VehAge + DrivAge +
 LogDensity + VehBrand + VehGas + Region)^2, data = train,
 model = fit_lgb, pred.fun = function(model, data) {
 newdata <- data.matrix(select(data, -Frequency, -Exposure))
 predict(model, newdata)
 }, weights = Exposure, lambda = 0.01, link = "log")

Link: log

Uninterpreted Variation Ratio:
 working response 
0.071637 0.181511 

Working Residuals:
     Min       1Q   Median       3Q      Max 
-0.51148 -0.04432 -0.00327  0.03899  3.63211 

Encoding:
           main.effect interaction
VehPower     linear(9)   linear(5)
VehAge      linear(18)   linear(5)
DrivAge     linear(25)   linear(5)
LogDensity  linear(25)   linear(5)
VehBrand     factor(6)   factor(6)
VehGas       factor(2)   factor(2)
Region       factor(7)   factor(7)
Code
# evaluate fitted surrogate
pred_mid_lgb <- predict(mid_lgb, test, type = "response")

deviance <- mean_poisson_deviance(
  y_true = test$Frequency,
  y_pred = pred_mid_lgb,
  sample_weight = test$Exposure
)
cat("Mean Poisson Deviance:", deviance)
Mean Poisson Deviance: 0.4670902

Model Fidelity

To measure how successfully our surrogate replicates the LightGBM model, we use the R-squared score on the link scale (\(\log\) scale).

Code
# calculate R-squared on testing dataset
R2_mid <- weighted.loss(
  x = log(pred_fit_lgb),
  y = log(pred_mid_lgb),
  w = test$Exposure,
  method = "r2"
)

cat(sprintf("R-squared: %.4f", R2_mid))
R-squared: 0.9295

Feature Effects

Code
par.midr(mfrow = c(2, 4))
mid.plots(mid_lgb, engine = "graphics", ylab = "Main Effect")

A key advantage of {midr} is its ability to isolate interaction effects \(g_{jk}\) from main effects \(g_j\). This is particularly useful to understand the joint impact of two variables (e.g., Region and LogDensity).

Code
grid.arrange(
  nrow = 1, widths = c(3, 2),
  ggmid(mid_lgb, "LogDensity:Region", type = "data",
        data = train[1:1e4, ]) +
    labs(y = NULL, subtitle = "Interaction Effect") +
    theme(legend.position = "bottom"),
  ggmid(mid_lgb, "LogDensity:Region", main.effects = TRUE) +
    labs(y = NULL, subtitle = "Total Effect") +
    scale_y_discrete(labels = NULL) +
    theme(legend.position = "bottom")
)

Code
par.midr(mar = c(1, 0, 1, 0), mfrow = c(1, 2))
persp(mid_lgb, "DrivAge:LogDensity", theta = 45, phi = 40)
persp(mid_lgb, "LogDensity:Region", theta = -45, phi = 40)

Effect Importance

To rank the influence of each component discovered in the LightGBM model, we calculate the Effect Importance, defined as the average absolute contribution.

Code
imp_lgb <- mid.importance(mid_lgb, data = train, max.nsamples = 2000)
grid.arrange(
  nrow = 1, widths = c(4, 3),
  ggmid(imp_lgb, theme = "bluescale@qual", max.nterms = 20) +
    labs(title = "Effect Importance",
         subtitle = "Average absolute effect per feature") +
    theme(legend.position = "none"),
  ggmid(imp_lgb, type = "beeswarm", theme = "mako@div", max.nterms = 20) +
    labs(title = "",
         subtitle = "Distribution of effect per feature") +
    scale_y_discrete(labels = NULL) +
    theme(legend.position = "none")
)

Conditional Expectation

We further explore the model’s behavior using the ICE plot. In the MID framework, the variation in ICE curves for a feature \(j\) is explicitly governed by the interaction terms \(g_{jk}\) identified from the LightGBM model.

Code
ice_lgb_link <- mid.conditional(mid_lgb, type = "link", variable = "DrivAge")
ice_lgb <- mid.conditional(mid_lgb, variable = "DrivAge")
grid.arrange(
  nrow = 1,
  ggmid(ice_lgb_link, var.color = LogDensity) +
    theme(legend.position = "bottom") +
    labs(y = "Linear Predictor",
         title = "Conditional Expectation",
         subtitle = "Change in linear predictor"),
  ggmid(ice_lgb, type = "centered", var.color = LogDensity) +
    theme(legend.position = "bottom") +
    labs(y = "Prediction", title = "",
         subtitle = "Centered change in original scale")
)

Additive Attribution

Finally, we perform an Additive Breakdown for individual predictions. This provides an exact allocation of the surrogate's prediction, and hence an approximate allocation of the LightGBM model's prediction, into its component terms.

Code
set.seed(42)
row_ids <- sort(sample(nrow(train), 4))

bd_plots <- lapply(
  row_ids,
  function(idx) {
    res <- mid.breakdown(mid_lgb, train, row = idx)
    label <- paste0("Breakdown for Row ", idx)
    ggmid(res, theme = "shap", max.nterms = 10,
          format.args = list(digits = 4)) +
      labs(x = "Linear Predictor", subtitle = label) +
      theme(legend.position = "none")
  }
)

grid.arrange(grobs = bd_plots)

Conclusion

In this notebook, we have demonstrated how Maximum Interpretation Decomposition (MID) bridges the gap between predictive performance and model transparency. By using the {midr} package, we successfully transformed a complex LightGBM model into a structured, additive representation.

While the surrogate model fidelity may not always be perfect, the crucial advantage lies in our ability to quantify its limitations. Through the uninterpreted variation ratio, we can directly assess the complexity of the black-box model. If the fidelity is lower than expected, it serves as a diagnostic signal that the original model relies on high-order interactions or structural complexities that extend beyond second-order effects.

Knowing the extent of this “unexplained” variance is far more valuable than operating in the dark. It allows actuaries to make informed decisions about whether the additional complexity of a black-box model is justified by its performance, or if a more transparent structure is preferable for regulatory and risk management purposes.

As machine learning models become increasingly prevalent in insurance pricing and reserving, tools like MID will be essential for ensuring that black-box models remain accountable, reliable, and fundamentally understood.