Quick Start

Here’s a basic example of how to use pyramid-learn (midlearn) to explain a trained LightGBM model through the familiar scikit-learn API.

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing
from sklearn import set_config

import lightgbm as lgb
import midlearn as mid
from IPython.display import display  # needed outside Jupyter, where display() is not predefined

# Set up plotnine theme for clean visualizations
import plotnine as p9
p9.theme_set(p9.theme_bw(base_family='serif'))

# Configure scikit-learn display
set_config(display='text')

1. Train a Black-Box Model

We use the California Housing dataset to train a LightGBM Regressor, which will serve as our black-box model.

# Load and prepare data
housing = fetch_california_housing()
X = pd.DataFrame(housing.data, columns=housing.feature_names)
y = housing.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit a LightGBM regression model
estimator = lgb.LGBMRegressor(random_state=42)
estimator.fit(X_train, y_train)

print(estimator)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000309 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1838
[LightGBM] [Info] Number of data points in the train set: 15480, number of used features: 8
[LightGBM] [Info] Start training from score 2.070349
LGBMRegressor(random_state=42)

2. Create an Explanation Model

We fit the MIDExplainer on the training data to build an interpretable surrogate model (MID) that mimics the LightGBM predictions across the whole training distribution, rather than around a single prediction.

# Initialize and fit the MID model
explainer = mid.MIDExplainer(
    estimator=estimator,
    interaction=True,
    params_main=100,
    penalty=.05,
    singular_ok=True,
)
explainer.fit(X_train)

print(explainer)
Generating predictions from the estimator...
MIDExplainer(estimator=LGBMRegressor(random_state=42), params_main=100,
             penalty=0.05)
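To make the idea of a surrogate concrete, here is a deliberately simplified sketch of what fitting a MID-style model involves, using only NumPy. It fits main effects only (no interactions), encodes each feature with piecewise-linear "hat" functions on quantile knots, and solves a ridge-penalized least-squares problem against the black-box predictions `y_hat`. The helper names `hat_basis` and `fit_additive_surrogate`, the quantile knots, and the plain L2 penalty are our own assumptions; midlearn's actual encoding, its `params_main` and `penalty` parameters, and its solver may work differently.

```python
import numpy as np

def hat_basis(x, knots):
    """Piecewise-linear ("hat") encoding of x on strictly increasing knots.

    Each row has at most two non-zero weights that sum to 1.
    """
    x = np.clip(np.asarray(x, dtype=float), knots[0], knots[-1])
    # Index of the knot interval containing each x
    idx = np.clip(np.searchsorted(knots, x, side="right") - 1, 0, len(knots) - 2)
    left, right = knots[idx], knots[idx + 1]
    w = (x - left) / (right - left)
    B = np.zeros((len(x), len(knots)))
    rows = np.arange(len(x))
    B[rows, idx] = 1.0 - w
    B[rows, idx + 1] = w
    return B

def fit_additive_surrogate(X, y_hat, n_knots=10, penalty=0.05):
    """Hypothetical sketch: additive surrogate fitted to black-box predictions y_hat."""
    X = np.asarray(X, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    # Knots at feature quantiles (duplicates removed so intervals are non-empty)
    knots = [np.unique(np.quantile(X[:, j], np.linspace(0, 1, n_knots)))
             for j in range(X.shape[1])]

    def design(Xn):
        return np.hstack([hat_basis(Xn[:, j], k) for j, k in enumerate(knots)])

    B = design(X)
    intercept = y_hat.mean()
    # Ridge normal equations: (B^T B + penalty * I) beta = B^T (y_hat - intercept)
    A = B.T @ B + penalty * np.eye(B.shape[1])
    beta = np.linalg.solve(A, B.T @ (y_hat - intercept))

    def predict(Xn):
        return intercept + design(np.asarray(Xn, dtype=float)) @ beta

    return predict
```

The ridge term also keeps the linear system solvable even though each feature's hat columns sum to one and are therefore collinear with one another.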
# Check the fidelity of the surrogate model to the original model
print("R-squared score:", explainer.fidelity_score(X_test))

# Visualize the fidelity: points close to the gray diagonal mean the two models agree
p = (
    p9.ggplot()
    + p9.geom_abline(slope=1, color='gray')
    + p9.geom_point(
        p9.aes(estimator.predict(X_test), explainer.predict(X_test)),
        alpha=0.5, shape="."
    )
    + p9.labs(
        x='Prediction (LightGBM Regressor)',
        y='Prediction (Surrogate MID Regressor)',
        title='Surrogate Model Fidelity Check'
    )
)
p
Generating predictions from the estimator...
R-squared score: 0.9493631147465892
[Figure: Surrogate Model Fidelity Check]
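The R-squared reported above measures how well the surrogate reproduces the black-box model, not the true targets. Assuming `fidelity_score` uses the usual definition with the LightGBM predictions as the reference (the "Generating predictions from the estimator..." message suggests it calls the estimator), it can be sketched as follows; `fidelity_r2` is a hypothetical helper name.

```python
import numpy as np

def fidelity_r2(y_blackbox, y_surrogate):
    """R-squared of surrogate predictions, treating black-box predictions as the target."""
    y_blackbox = np.asarray(y_blackbox, dtype=float)
    y_surrogate = np.asarray(y_surrogate, dtype=float)
    ss_res = np.sum((y_blackbox - y_surrogate) ** 2)          # disagreement between models
    ss_tot = np.sum((y_blackbox - y_blackbox.mean()) ** 2)    # spread of black-box predictions
    return 1.0 - ss_res / ss_tot
```

With this argument order the result matches `sklearn.metrics.r2_score(y_blackbox, y_surrogate)`. A value of 1 means the surrogate reproduces the black box exactly; 0 means it does no better than predicting the black box's average.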

3. Visualize the Explanation Model

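The fitted MID model can be visualized in several ways: overall term importance, the shape of individual main and interaction effects, local breakdowns of single predictions, and individual conditional expectation (ICE) curves.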

# Calculate and plot overall feature importance (default bar plot and heatmap)
imp = explainer.importance()
display(
    imp.plot(max_nterms=20) +
    p9.ggtitle("Importance Plot")
)
display(
    imp.plot(style='heatmap') +
    p9.ggtitle("Importance Heatmap")
)
[Figures: Importance Plot; Importance Heatmap]
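A common way to define term importance for a functional decomposition is the mean absolute contribution of each term over a dataset. Assuming `explainer.importance()` follows this definition (check the midlearn documentation), the ranking in the bar plot could be reproduced from per-row term values like this; `term_importance` and the toy values are illustrative only.

```python
import numpy as np

def term_importance(contributions):
    """Mean absolute contribution of each term, sorted from most to least important.

    contributions: dict mapping a term name (e.g. "a" or "a:b") to per-row term values.
    """
    imp = {t: float(np.mean(np.abs(v))) for t, v in contributions.items()}
    return dict(sorted(imp.items(), key=lambda kv: kv[1], reverse=True))

# Toy per-row values for two main effects and one interaction (hypothetical)
print(term_importance({
    "a": [1, -1, 1, -1],
    "b": [0.5, 0.5, -0.5, -0.5],
    "a:b": [0.1, -0.1, 0.0, 0.0],
}))
```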
# Plot the three most important main effects (Component Functions)
for t in imp.terms(interactions=False)[:3]:
    p = (
        explainer.plot(term=t) +
        p9.ggtitle(f"Main Effect of {t}")
    )
    display(p)
[Figures: Main Effect of each of the three most important terms]
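Each main-effect plot shows one component function on its own scale. In a MID-style decomposition, component functions are usually constrained to average zero over the data, with any constant offset absorbed into the intercept, so a plot reads as a deviation from the average prediction. The sketch below uses a hypothetical helper `center_terms` on toy values to show that such centering leaves every prediction unchanged; whether midlearn centers with exactly this unweighted mean is an assumption.

```python
import numpy as np

def center_terms(intercept, contributions):
    """Shift each term to mean zero over the data and move the offsets into the intercept.

    Row-level predictions, intercept + sum of terms, are unchanged.
    """
    new_intercept = float(intercept)
    centered = {}
    for term, values in contributions.items():
        values = np.asarray(values, dtype=float)
        offset = values.mean()
        centered[term] = values - offset   # term now averages zero
        new_intercept += offset            # offset absorbed by the intercept
    return new_intercept, centered
```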
# Plot the total effect of Longitude and Latitude (interaction plus both main effects)
display(
    explainer.plot(
        "Longitude:Latitude",
        theme='mako',
        main_effects=True
    ) +
    p9.ggtitle("Total Effect of Longitude and Latitude")
)
# The same effect, drawn over the training data (style='data')
display(
    explainer.plot(
        "Longitude:Latitude",
        style='data',
        theme='mako',
        data=X_train,
        main_effects=True
    ) +
    p9.ggtitle("Total Effect of Longitude and Latitude")
)
[Figures: Total Effect of Longitude and Latitude (default style; style='data')]
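With `main_effects=True`, the plot title refers to a total effect, which we read as the interaction component plus both main effects. Under that assumption, the surface on a grid can be computed as below; `total_effect_grid` and the toy component functions are hypothetical stand-ins for the fitted Longitude, Latitude, and Longitude:Latitude terms.

```python
import numpy as np

def total_effect_grid(f_j, f_k, f_jk, xj_grid, xk_grid):
    """Main effects plus interaction on a 2-D grid.

    Rows index xk_grid and columns index xj_grid (NumPy's default "xy" meshgrid layout).
    """
    XJ, XK = np.meshgrid(xj_grid, xk_grid)
    return f_j(XJ) + f_k(XK) + f_jk(XJ, XK)

# Toy components: f_j(a) = a, f_k(b) = 2b, f_jk(a, b) = a * b
surface = total_effect_grid(
    lambda a: a, lambda b: 2 * b, lambda a, b: a * b,
    np.array([0.0, 1.0]), np.array([0.0, 1.0, 2.0]),
)
print(surface)
```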
# Plot prediction breakdowns for the first three test samples (Local Interpretability)
for i in range(3):
    p = (
        explainer.breakdown(row=i, data=X_test).plot() +
        p9.ggtitle(f"Breakdown Plot for Row ({i})")
    )
    display(p)
[Figures: Breakdown Plot for Row (0); Row (1); Row (2)]
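A breakdown plot decomposes one prediction into the intercept plus the contribution of each term for that row. Assuming the plot orders terms by absolute contribution (a common convention, not confirmed here), the bookkeeping looks like this; `breakdown_row` and its input values are illustrative.

```python
def breakdown_row(intercept, contributions):
    """Sort one row's term contributions by absolute size and rebuild the prediction.

    contributions: dict mapping a term name to that term's value for the row.
    """
    items = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
    prediction = intercept + sum(value for _, value in items)
    return items, prediction

# Hypothetical row: intercept 2.0 and three term contributions
items, prediction = breakdown_row(2.0, {"x1": 0.8, "x2": -0.3, "x1:x2": 0.05})
for term, value in items:
    print(f"{term:>6}: {value:+.2f}")
print(f"prediction: {prediction:.2f}")
```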
# Plot individual conditional expectation (ICE) curves of MedInc for 500 training rows,
# then a centered version colored by HouseAge
ice = explainer.conditional(
    variable='MedInc',
    data=X_train.head(500)
)
display(
    ice.plot(alpha=.1) +
    p9.ggtitle("ICE Plot of MedInc")
)
display(
    ice.plot(
        style='centered',
        var_color='HouseAge',
        theme='mako'
    ) +
    p9.ggtitle("Centered ICE Plot of MedInc")
)
[Figures: ICE Plot of MedInc; Centered ICE Plot of MedInc]
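ICE curves are computed by holding each row fixed, sweeping one feature over a grid, and recording the model's predictions; the centered variant subtracts each curve's value at the first grid point so that differences in shape stand out. The NumPy sketch below assumes a generic `predict` function that takes a 2-D array. `ice_curves`, `center_ice`, and anchoring at the first grid point are our own choices, and midlearn's `style='centered'` may anchor differently.

```python
import numpy as np

def ice_curves(predict, X, column, grid):
    """Return an (n_rows, n_grid) array: predictions with X[:, column] set to each grid value."""
    X = np.asarray(X, dtype=float)
    curves = np.empty((X.shape[0], len(grid)))
    for g, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, column] = value          # override the swept feature for every row
        curves[:, g] = predict(X_mod)
    return curves

def center_ice(curves):
    """Centered ICE: subtract each curve's value at the first grid point."""
    return curves - curves[:, [0]]

# Toy model whose slope in column 0 depends on column 1 (an interaction)
predict = lambda X: X[:, 0] * X[:, 1] + X[:, 1]
curves = ice_curves(predict, [[5, 1], [5, 2]], column=0, grid=[0, 1, 2])
print(center_ice(curves))
```

Non-parallel centered curves, as in this toy model, indicate that the swept feature interacts with another one; coloring the curves by a second feature such as `HouseAge` helps show which.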