Skip to contents

This article presents some examples of the interpretation of regression models using midr.

Regression Task

We use a benchmark regression task, originally described in Friedman (1991) and Breiman (1996), and implemented in the mlbench package. The dataset has 10 independent predictor variables $x_1, x_2, \ldots, x_{10}$, each uniformly distributed on the interval $[0, 1]$, and the response variable $y$, generated according to the following formula with disturbance term $\epsilon \sim \mathcal{N}(0, 1)$: $$y = 10\sin(\pi x_1 x_2) + 20(x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + \epsilon$$

The following plots show the effect of each predictor variable on the response. For $x_1$ and $x_2$, the interaction effect is shown by the colored lines: the effect of $x_1$ depends on the value of $x_2$ (pale purple for 0 and dark red for 1) and vice versa.

# benchmark regression task
library(mlbench)
# Fix the RNG seed once; the three draws below consume the random stream in
# order, so reordering them would change every dataset.
set.seed(42)
# 500 observations each for model fitting and for held-out evaluation.
train  <- as.data.frame(mlbench.friedman1(n = 500L))
test   <- as.data.frame(mlbench.friedman1(n = 500L))
# 2500 observations for fitting the MID surrogates; column 11 (the response)
# is dropped.
mtrain <- as.data.frame(mlbench.friedman1(n = 2500L))[, -11L]

For each model type, we fit a regression model using the train data of 500 observations and an interpretative MID surrogate of the target model using the mtrain data of 2500 observations without the response variable. We then evaluate the predictive accuracy of the target model by the RMSE between its predictions and the test response, and the interpretative accuracy of the MID surrogate by the RMSE between the model's predictions and the surrogate's predictions.

# define utility functions for the following chunks
# Main-effect plots of a fitted MID object for the terms x.1 through x.6.
# Returns whatever mid.plots() returns for those terms.
effect_plots <- function(object) {
  main_terms <- paste0("x.", seq_len(6L))
  mid.plots(object, terms = main_terms)
}

# Plot of the x.1:x.2 interaction term of a MID object, drawn with
# main.effects = TRUE and titled "interaction effect".
interaction_plot <- function(object) {
  term_plot <- ggmid(object, "x.1:x.2", main.effects = TRUE)
  term_plot + ggtitle("interaction effect")
}

# Color theme shared by every conditional-expectation plot.
ice_theme <- color.theme("smoothrainbow")

# Centered conditional-expectation curves of a MID object along x.1, colored
# by x.2. `data` supplies the observations to draw curves for and defaults to
# the first 200 rows of the global `train`.
ice_plot <- function(object, data = train[1:200, ]) {
  conditional <- mid.conditional(object, "x.1", data = data)
  curves <- ggmid(conditional, var.color = x.2, type = "centered",
                  theme = ice_theme)
  curves + ggtitle("conditional expectation")
}

# Heatmap of the term importances computed from a MID object, titled
# "feature importance".
importance_plot <- function(object) {
  importance <- mid.importance(object)
  ggmid(importance, "heatmap") + ggtitle("feature importance")
}

# Plot model predictions against (a) the observed response and (b) the MID
# surrogate's predictions, with a dashed y = x reference line and both RMSE
# values annotated near the top-left of the panel.
#
# model: fitted model whose predictions are obtained with get.yhat().
# mid:   MID surrogate; predicted with get.yhat() on the same data.
# data:  data to predict on (defaults to the global `test`). If it is a data
#        frame with a `y` column, that column is the observed response;
#        otherwise (predictor-only data frame or matrix) the response falls
#        back to `test$y`, so `data` must then hold the rows of `test`.
# ...:   extra arguments forwarded to get.yhat() for `model` only.
eval_plot <- function(model, mid, data = test, ...) {
  pred <- get.yhat(model, data, ...)
  pred_mid <- get.yhat(mid, data)
  # Take the response from `data` when it is available instead of always
  # reading the global `test`, so labelled data other than `test` is compared
  # against its own response.
  has_response <- is.data.frame(data) && "y" %in% names(data)
  actual <- if (has_response) data[["y"]] else test$y
  rmse_vs_test <- rmse(pred, actual)
  rmse_vs_mid <- rmse(pred, pred_mid)
  ggplot() + scale_color_theme("DALEX") +
    geom_point(aes(x = pred, y = actual, col = "vs test")) +
    geom_point(aes(x = pred, y = pred_mid, col = "vs mid")) +
    geom_abline(slope = 1, intercept = 0, col = "black", lty = 2) +
    labs(x = "model-prediction", y = "mid-prediction / test") +
    annotate(
      "text", family = "serif", size = 3,
      # Offset the label by one eighth of each axis range from the corner.
      x = min(pred) + diff(range(pred)) / 8,
      y = max(actual) - diff(range(actual)) / 8,
      label = sprintf("RMSE\nvs test: %.3f\nvs mid: %.3f",
                      rmse_vs_test, rmse_vs_mid)
    ) + ggtitle("prediction / representation accuracy")
}

Additive Models

Linear Model

# Fit a linear model of y on all ten predictors using the training data.
model <- lm(y ~ ., train)
coef(model)
#> (Intercept)         x.1         x.2         x.3         x.4         x.5 
#>   0.1302510   6.8458545   6.8892805  -0.4403955  10.3264576   4.6735425 
#>         x.6         x.7         x.8         x.9        x.10 
#>   0.5837944   0.2030152  -0.6272202  -0.1722106   0.3453933
# Fit the MID surrogate (main effects plus all two-way interactions) to the
# model's predictions on mtrain, which carries no response column.
mid <- interpret(y ~ .^2, mtrain, model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = mtrain, model = model)
#> 
#> Model Class: lm
#> 
#> Intercept: 14.319
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Regularized GLM

library(glmnet)
# glmnet takes a predictor matrix and a response vector; column 11 of train
# is the response y.
model <- glmnet(x = as.matrix(train[, -11]), y = train[, 11])
# prediction with arbitrarily chosen lambda
# The chosen lambda is passed to the model's prediction through pred.args.
# NOTE(review): the surrogate is fitted on train[, -11] (500 rows) rather than
# on mtrain as in most other sections; the recorded output below reflects that.
mid <- interpret(y ~ .^2, train[, -11], model,
                 pred.args = list(s = model$lambda[9]))
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = train[, -11], model = model,
#>  pred.args = list(s = model$lambda[9]))
#> 
#> Model Class: elnet, glmnet
#> 
#> Intercept: 14.417
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Evaluate on the predictor-only test data; `s` is forwarded by eval_plot() to
# get.yhat() so the same lambda is used for the model's predictions.
evp <- eval_plot(model, mid, data = test[, -11],
                       s = model$lambda[9])
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), evp, nrow = 2)

Generalized Additive Model

library(gam)
# Additive model with one s() smooth term per predictor and a Gaussian family.
model <- gam(y ~ s(x.1) + s(x.2) + s(x.3) + s(x.4) + s(x.5) +
             s(x.6) + s(x.7) + s(x.8) + s(x.9) + s(x.10),
             family = gaussian, data = train)
# NOTE(review): the surrogate is fitted on train (500 rows, response column
# included) rather than on mtrain as in most other sections; the recorded
# output below reflects that.
mid <- interpret(y ~ .^2, train, model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = train, model = model)
#> 
#> Model Class: Gam, glm, lm
#> 
#> Intercept: 14.417
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 3.2805e-05
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Neural Network

Single Hidden Layer Network

library(nnet)
# Seed the RNG before fitting so the result is reproducible (presumably
# because nnet starts from random initial weights; confirm in the nnet docs).
set.seed(42)
# Hidden layer of size 5, linear output (linout = TRUE) for regression, at
# most 1000 iterations, with fitting trace output switched off.
model <- nnet(y ~ ., train, size = 5, linout = TRUE, maxit = 1e3, trace = FALSE)
# Fit the MID surrogate to the network's predictions on mtrain.
mid <- interpret(y ~ .^2, mtrain, model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = mtrain, model = model)
#> 
#> Model Class: nnet.formula, nnet
#> 
#> Intercept: 14.195
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.0028096
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Support Vector Machine

RBF Kernel SVM

library(e1071)
# Support vector regression of y on all predictors with a radial (RBF) kernel.
model <- svm(y ~ ., train, kernel = "radial")
# Fit the MID surrogate to the SVM's predictions on mtrain.
mid <- interpret(y ~ .^2, mtrain, model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = mtrain, model = model)
#> 
#> Model Class: svm.formula, svm
#> 
#> Intercept: 14.32
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.0075539
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Tree Based Models

Gradient Boosting Trees

library(xgboost)
# Booster settings: learning rate 0.1, 70% row subsampling, tree depth up to 5.
params <- list(eta = .1, subsample = .7, max_depth = 5)
# Seed the RNG before fitting so the result is reproducible (subsample < 1
# presumably makes training random; confirm in the xgboost docs).
set.seed(42)
# Predictors are passed as a matrix and the response (column 11) as a vector.
model <- xgboost(as.matrix(train[, -11]), train[, 11], nrounds = 100,
                 params = params, verbose = 0)
# The surrogate's data is passed as a matrix here, unlike the other sections.
mid <- interpret(y ~ .^2, as.matrix(mtrain), model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = as.matrix(mtrain), model = model)
#> 
#> Model Class: xgb.Booster
#> 
#> Intercept: 14.172
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.011813
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Evaluate on the test predictors as a matrix (third positional argument is
# eval_plot()'s `data`).
evp <- eval_plot(model, mid, as.matrix(test[, -11]))
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), evp, nrow = 2)

Random Forest

library(ranger)
# Seed the RNG before fitting so the forest is reproducible.
set.seed(42)
# Random forest of y on all predictors with mtry = 5.
model <- ranger(y ~ ., train, mtry = 5)
# Fit the MID surrogate to the forest's predictions on mtrain.
mid <- interpret(y ~ .^2, mtrain, model)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = mtrain, model = model)
#> 
#> Model Class: ranger
#> 
#> Intercept: 14.27
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.0075668
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Decision Tree

library(rpart)
# Regression tree of y on all predictors with rpart's default settings.
model <- rpart(y ~ ., train)
# create encoding frames for CART
# Combine the tree's node table with its split labels, then print the split
# variable and left-branch condition of every non-leaf node.
frm <- cbind(model$frame, labels(model, collapse = FALSE))
print(t(frm[frm$var != "<leaf>", c("var", "ltemp")]))
#>       1          2        5         11         22         3          7         
#> var   "x.4"      "x.1"    "x.2"     "x.4"      "x.5"      "x.2"      "x.1"     
#> ltemp "< 0.5579" "< 0.21" "< 0.311" "< 0.2953" "< 0.5849" "< 0.2653" "< 0.3184"
#>       14         15         31         62        
#> var   "x.4"      "x.5"      "x.4"      "x.2"     
#> ltemp "< 0.8843" "< 0.2486" "< 0.8413" "< 0.4782"
# Start each variable's frame from its observed range in mtrain, then append
# the split thresholds printed above for the variables the tree splits on.
frames <- lapply(mtrain, range)
frames$x.1 <- c(frames$x.1, .2100, .3184)
frames$x.2 <- c(frames$x.2, .3110, .2653, .4782)
frames$x.4 <- c(frames$x.4, .5579, .2953, .8843, .8413)
frames$x.5 <- c(frames$x.5, .5849, .2486)
# NOTE(review): type = 0 presumably selects a piecewise-constant encoding that
# matches the tree's step-shaped predictions; confirm in the midr docs.
mid <- interpret(y ~ .^2, mtrain, model, type = 0, frames = frames)
print(mid)
#> 
#> Call:
#> interpret(formula = yhat ~ .^2, data = mtrain, model = model,
#>  type = 0, frames = frames)
#> 
#> Model Class: rpart
#> 
#> Intercept: 14.264
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.031256
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# Interaction, importance, conditional-expectation and accuracy plots in two
# rows.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)

Other Modes

Predictive MID

# No target model is passed, so the MID model is fitted directly to the
# response y in train (see the message below); the same object is stored as
# both `model` and `mid`.
model <- mid <- interpret(y ~ .^2, train, lambda = .2)
#> 'model' not passed: response variable in 'data' is used
# NOTE(review): pred and pred_mid are not used by any later code visible here;
# eval_plot() computes its own predictions.
pred <- pred_mid <- predict(mid, test)
print(mid)
#> 
#> Call:
#> interpret(formula = y ~ .^2, data = train, lambda = 0.2)
#> 
#> Intercept: 14.417
#> 
#> Main Effects:
#> 10 main effect terms
#> 
#> Interactions:
#> 45 interaction terms
#> 
#> Uninterpreted Variation Ratio: 0.031792
# Main-effect plots from effect_plots(), arranged in two rows.
grid.arrange(grobs = effect_plots(mid), nrow = 2L)

# `model` and `mid` are the same object here, so eval_plot()'s "vs mid"
# comparison is the model against itself.
grid.arrange(interaction_plot(mid), importance_plot(mid),
             ice_plot(mid), eval_plot(model, mid), nrow = 2)