
interpret() fits a MID (Maximum Interpretation Decomposition) model, primarily as an interpretable global surrogate for a black-box predictive model; it can also be fitted directly to an observed response. A fitted MID model consists of a set of component functions, each involving at most two variables.

Usage

interpret(object, ...)

# Default S3 method
interpret(
  object,
  x,
  y = NULL,
  weights = NULL,
  pred.fun = get.yhat,
  link = NULL,
  k = c(NA, NA),
  type = c(1L, 1L),
  frames = list(),
  interaction = FALSE,
  terms = NULL,
  singular.ok = FALSE,
  mode = 1L,
  method = NULL,
  lambda = 0,
  kappa = 1e+06,
  na.action = getOption("na.action"),
  verbosity = 1L,
  encoding.digits = 3L,
  use.catchall = FALSE,
  catchall = "(others)",
  max.ncol = 10000L,
  nil = 1e-07,
  tol = 1e-07,
  pred.args = list(),
  ...
)

# S3 method for class 'formula'
interpret(
  formula,
  data = NULL,
  model = NULL,
  pred.fun = get.yhat,
  weights = NULL,
  subset = NULL,
  na.action = getOption("na.action"),
  verbosity = 1L,
  mode = 1L,
  drop.unused.levels = FALSE,
  pred.args = list(),
  ...
)

Arguments

object

a fitted model object to be interpreted.

...

for interpret.default(), optional arguments such as fit.intercept, interpolate.beta, weighted.norm and weighted.encoding. Short aliases are also accepted for some arguments, such as ok for singular.ok and ie for interaction. For interpret.formula(), additional arguments passed on to interpret.default().

x

a matrix or data.frame of predictor variables to be used in the fitting process. The response variable should not be included.

y

an optional numeric vector of the model predictions or the response variable.

weights

a numeric vector of sample weights for each observation in x.

pred.fun

a function to obtain predictions from a fitted model, where the first argument is for the fitted model and the second argument is for new data. The default is get.yhat().

link

a character string specifying the link function: one of "logit", "probit", "cauchit", "cloglog", "identity", "log", "sqrt", "1/mu^2", "inverse", "translogit", "transprobit", "identity-logistic" or "identity-gaussian"; or an object containing the two functions linkfun() and linkinv(). See help(make.link).

k

an integer or an integer-valued vector of length two giving the maximum number of sample points for each variable. If a vector of length two is passed, k[1L] is used for main effects and k[2L] for interactions. If a single integer is passed, k is used for main effects and sqrt(k) for interactions. If k is not positive, all unique values are used as sample points.

type

an integer or integer-valued vector of length two. The type of encoding. The effects of quantitative variables are modeled as piecewise linear functions if type is 1, and as step functions if type is 0. If a vector is passed, type[1L] is used for main effects and type[2L] is used for interactions.

frames

a named list of encoding frames ("numeric.frame" or "factor.frame" objects). The encoding frames are used to encode the variable of the corresponding name. If the name begins with "|" or ":", the encoding frame is used only for main effects or interactions, respectively.

interaction

logical. If TRUE and neither terms nor formula is supplied, all pairwise interactions between the variables are modeled.

terms

a character vector of term labels specifying the set of component functions to be modeled. If not passed, terms includes all main effects, and all interactions if interaction is TRUE.

singular.ok

logical. If FALSE, a singular fit is an error.

mode

an integer specifying the method of calculation. If mode is 1, the centering constraints are treated as penalties in the least squares problem, with penalty factor kappa. If mode is 2, the constraints are used to reduce the number of free parameters.

method

an integer specifying the method used to solve the least squares problem. A non-negative value is passed as the method argument of RcppEigen::fastLmPure(). If negative, stats::lm.fit() is used.

lambda

the penalty factor for pseudo smoothing. The default is 0.

kappa

the penalty factor for centering constraints. Used only when mode is 1. The default is 1e+6.

na.action

a function or character string specifying how NAs are handled. The default is getOption("na.action"), which is "na.omit" unless the option has been changed.

verbosity

the level of verbosity. 0: fatal, 1: warning (default), 2: info or 3: debug.

encoding.digits

an integer. The rounding digits for encoding numeric variables. Used only when type is 1.

use.catchall

logical. If TRUE, less frequent levels of qualitative variables are dropped and replaced by the catchall level.

catchall

a character string specifying the catchall level.

max.ncol

integer. The maximum number of columns of the design matrix.

nil

a threshold for the intercept and coefficients to be treated as zero. The default is 1e-7.

tol

a tolerance for the singular value decomposition. The default is 1e-7.

pred.args

a list of optional arguments, other than the fitted model and new data, to be passed to pred.fun().

formula

a symbolic description of the MID model to be fit.

data

a data.frame, list or environment containing the variables in formula. If not found in data, the variables are taken from environment(formula).

model

a fitted model object to be interpreted.

subset

an optional vector specifying a subset of observations to be used in the fitting process.

drop.unused.levels

logical. If TRUE, unused levels of factors will be dropped.
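The two encoding types selected by type can be pictured with the base-R sketch below. It is not midr's encoder: the helpers encode_linear() and encode_step() are hypothetical, the knots are assumed to be given and sorted, and values outside the knot range are simply clamped. With type-1-style encoding each value spreads a total weight of one over its two neighbouring knots, so a coefficient vector traces a piecewise linear function; with type-0-style encoding each value falls into exactly one interval, so the effect is a step function.

```r
# Hypothetical helpers for illustration only; not part of midr.

# Piecewise linear ("type = 1"-style) encoding: each value is split between its
# two neighbouring knots by linear interpolation, so each row sums to one.
encode_linear <- function(x, knots) {
  x <- pmin(pmax(x, min(knots)), max(knots))          # clamp to the knot range
  i <- findInterval(x, knots, rightmost.closed = TRUE) # left neighbouring knot
  w <- (x - knots[i]) / (knots[i + 1L] - knots[i])     # position within interval
  X <- matrix(0, nrow = length(x), ncol = length(knots))
  X[cbind(seq_along(x), i)] <- 1 - w
  X[cbind(seq_along(x), i + 1L)] <- w
  X
}

# Step-function ("type = 0"-style) encoding: each value is assigned to one of
# the length(knots) - 1 intervals between consecutive knots (one-hot columns).
encode_step <- function(x, knots) {
  i <- findInterval(x, knots, rightmost.closed = TRUE, all.inside = TRUE)
  X <- matrix(0, nrow = length(x), ncol = length(knots) - 1L)
  X[cbind(seq_along(x), i)] <- 1
  X
}

knots <- c(0, 1, 2)
x <- c(0, 0.5, 1.5, 2)
drop(encode_linear(x, knots) %*% c(10, 20, 0))  # 10 15 10 0: interpolated
drop(encode_step(x, knots) %*% c(12, 5))        # 12 12 5 5: constant per interval
```

In both cases the fitted effect is linear in the coefficients, which is why a MID model can be estimated by ordinary least squares once the variables are encoded.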

Value

interpret() returns a "mid" object with the following components:

weights

a numeric vector of the sample weights.

call

the matched call.

terms

the term labels.

link

a "link-glm" or "link-midr" object containing the link function.

intercept

the intercept.

encoders

a list of variable encoders.

main.effects

a list of data frames representing the main effects.

interactions

a list of data frames representing the interactions.

ratio

the ratio of the sum of squared errors between the target model predictions and the fitted MID values to the sum of squared deviations of the target model predictions from their mean.

fitted.matrix

a matrix showing the breakdown of the predictions into the effects of the component functions.

linear.predictors

a numeric vector of the linear predictors.

fitted.values

a numeric vector of the fitted values.

residuals

a numeric vector of the working residuals.

na.action

information about the special handling of NAs.
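The link component and the linear.predictors / fitted.values pair follow the generalized linear model convention. Assuming, as for glm objects, that the fitted values are the inverse link applied to the linear predictors, the relation can be checked with stats::make.link(), which returns a "link-glm" object. The hand-built log10_link list below is a hypothetical example of "an object containing the two functions linkfun() and linkinv()" as described for the link argument; it is not passed to interpret() here.

```r
# stats::make.link() returns a "link-glm" object: linkfun() maps the response
# scale to the link scale and linkinv() maps it back.
logit <- stats::make.link("logit")
eta <- c(-2, 0, 2)           # hypothetical linear predictors (link scale)
mu  <- logit$linkinv(eta)    # corresponding fitted values (response scale)
all.equal(logit$linkfun(mu), eta)

# Hypothetical hand-built link object with the two required functions.
log10_link <- list(
  linkfun = function(mu) log10(mu),
  linkinv = function(eta) 10^eta
)
log10_link$linkinv(log10_link$linkfun(c(1, 10, 100)))
```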

Details

interpret() returns a global surrogate model of the target predictive model. The prediction function of this surrogate model is derived from Maximum Interpretation Decomposition (MID) applied to the prediction function of the target model (denoted \(f(\mathbf{x})\)).

The prediction function of the global surrogate model, denoted \(\mathcal{F}(\mathbf{x})\), has the following structure: $$\mathcal{F}(\mathbf{x}) = f_\phi + \sum_{j} f_{j}(x_j) + \sum_{j<k} f_{jk}(x_j, x_k)$$ where \(f_\phi\) is the intercept, \(f_{j}(x_j)\) is the main effect of feature \(j\), and \(f_{jk}(x_j, x_k)\) is the second-order interaction effect between features \(j\) and \(k\).

To ensure the identifiability (uniqueness) of these decomposed components, they are subject to centering constraints during the fitting process. Specifically, each main effect function \(f_j(x_j)\) is constrained such that its average over the data distribution of feature \(X_j\) is zero. Similarly, each second-order interaction effect function \(f_{jk}(x_j, x_k)\) is constrained such that its conditional average over \(X_j\) (for any fixed value \(x_k\)) is zero, and its conditional average over \(X_k\) (for any fixed value \(x_j\)) is also zero.
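The centering constraints can be made concrete with a small base-R sketch. It centers effects after the fact, which is not how interpret() works (the constraints are built into the fit), and for the interaction it assumes the observations form a complete, equally weighted grid, in which case the conditional averages reduce to column and row means. The values of fj, wj and fjk are arbitrary.

```r
# Main effect of X_j at four levels, with the share of observations per level.
fj <- c(3, 1, 4, 1)
wj <- c(0.1, 0.2, 0.3, 0.4)
fj_c <- fj - sum(wj * fj)   # subtract the weighted average
sum(wj * fj_c)              # ~0: average over the distribution of X_j

# Interaction on a 3 x 4 grid: rows are levels of X_j, columns levels of X_k.
# Assumes every cell is observed equally often (uniform weights).
fjk <- matrix(c(5, 2, 8, 1, 9, 3, 7, 4, 6, 0, 2, 5), nrow = 3)
# Remove row means and column means, then add back the grand mean.
fjk_c <- sweep(sweep(fjk, 1, rowMeans(fjk)), 2, colMeans(fjk)) + mean(fjk)
colMeans(fjk_c)  # ~0: average over X_j for each fixed x_k
rowMeans(fjk_c)  # ~0: average over X_k for each fixed x_j
```

The row and column means removed from fjk are functions of a single variable, so in a full decomposition they would be absorbed into the main effects rather than discarded.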

The surrogate model is fitted using the least squares method, which minimizes the squared error between the predictions of the target model \(f(\mathbf{x})\) and the surrogate model \(\mathcal{F}(\mathbf{x})\) (typically evaluated on a representative dataset).
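A main-effects-only version of this least squares fit can be sketched in base R. The sketch mirrors the idea behind mode = 1, where the centering constraints enter as penalty rows scaled by sqrt(kappa), uses a step-function (one-hot) encoding of two discrete variables, and computes the ratio reported in the "mid" object. The target function f, the data and all variable names are hypothetical, and midr's actual implementation (encoding, weighting, choice of solver) may differ.

```r
# Hypothetical "black-box" prediction function with an x1:x2 interaction.
f <- function(x1, x2) 2 * x1 + x2 + 0.5 * x1 * x2

grid <- expand.grid(x1 = 1:3, x2 = 1:2)
dat  <- grid[c(1, 1, 2, 3, 3, 3, 4, 5, 6, 6), ]  # uneven representative sample
yhat <- f(dat$x1, dat$x2)                          # target model predictions

# One-hot encoding of each variable plus an intercept column.
X1 <- outer(dat$x1, 1:3, "==") * 1
X2 <- outer(dat$x2, 1:2, "==") * 1
X  <- cbind(1, X1, X2)  # rank-deficient: each block sums to the intercept

# Centering constraints as penalty rows: the data-weighted average of each
# main effect should be zero; sqrt(kappa) makes violations very costly.
kappa <- 1e6
C <- rbind(c(0, colMeans(X1), 0, 0),
           c(0, 0, 0, 0, colMeans(X2)))
fit  <- stats::lm.fit(rbind(X, sqrt(kappa) * C), c(yhat, 0, 0))
beta <- fit$coefficients

intercept <- beta[1]
f1 <- beta[2:4]  # main effect of x1 at levels 1, 2, 3
f2 <- beta[5:6]  # main effect of x2 at levels 1, 2

# Surrogate predictions and the share of unexplained variation.
pred  <- drop(X %*% beta)
ratio <- sum((yhat - pred)^2) / sum((yhat - mean(yhat))^2)
```

Because the penalty only acts on directions that the data cannot identify, the constraints end up satisfied up to rounding without worsening the fit, and the intercept equals the mean of the target predictions. The ratio is positive here because f contains an x1 * x2 interaction that main effects alone cannot represent.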

Examples

library(midr)

# fit a MID model as a surrogate model
data(cars, package = "datasets")
model <- lm(dist ~ I(speed^2) + speed, cars)
mid <- interpret(dist ~ speed, cars, model)
plot(mid, "speed", intercept = TRUE)
points(cars)


# customize the flexibility of a MID model
data(Nile, package = "datasets")
mid <- interpret(x = 1L:100L, y = Nile, k = 100L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)

# reduce the number of knots by setting the 'k' parameter
mid <- interpret(x = 1L:100L, y = Nile, k = 10L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)

# perform a pseudo smoothing by setting the 'lambda' parameter
mid <- interpret(x = 1L:100L, y = Nile, k = 100L, lambda = 100L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)


# fit a MID model as a predictive model
data(airquality, package = "datasets")
mid <- interpret(Ozone ~ .^2, na.omit(airquality), lambda = .4)
#> 'model' not passed: response variable in 'data' is used
plot(mid, "Wind")

plot(mid, "Temp")

plot(mid, "Wind:Temp", theme = "RdBu")

plot(mid, "Wind:Temp", main.effects = TRUE)