# src/midlearn/plotting.py
from __future__ import annotations
import numpy as np
import pandas as pd
import plotnine as p9
from typing import TYPE_CHECKING, Literal
from . import plotting_theme as pt
from . import utils
if TYPE_CHECKING:
from .api import (
MIDRegressor,
MIDExplainer,
MIDImportance,
MIDBreakdown,
MIDConditional
)
[docs]
def plot_effect(
estimator: MIDRegressor | MIDExplainer,
term: str,
style: Literal['effect', 'data'] = 'effect',
theme: str | pt.color_theme | None = None,
intercept: bool = False,
main_effects: bool = False,
data: pd.DataFrame | None = None,
jitter: float = 0.3,
resolution: int | tuple[int, int] = 100,
**kwargs
):
"""Visualize the estimated main or interaction effect of a fitted MID model with plotnine.
This is a porting function for the R function `midr::ggmid.mid()`.
Parameters
----------
estimator : MIDRegressor or MIDExplainer
A fitted MIDRegressor or MIDExplainer object containing the model components.
term : str
The name of the component function (main effect or interaction term) to plot.
style : {'effect', 'data'}, default 'effect'
The plotting style.
'effect' plots the estimated component function as a line or a surface.
'data' plots the specified data points (jittered for factor variables) with MID values represented by color.
theme : str or pt.color_theme or None, default None
The color theme to use for the plot.
intercept : bool, default False
If True, the global intercept term is added to the component function values.
main_effects : bool, default False
If True, main effects are included when plotting two-way interaction terms.
Ignored for single-term plots.
data : pandas.DataFrame or None, default None
The data frame to plot. Required only if `style='data'`.
jitter : float, default 0.3
The amount of jitter to apply to factor variables when `style='data'` is used.
resolution : int or tuple[int, int], default 100
The resolution (number of grid points) for calculating the effect.
If a single integer, it is used for both axes of a 2D interaction plot.
If a tuple (int, int), it specifies the resolution for the first and
second predictor in an interaction, respectively.
**kwargs : dict
Additional keyword arguments passed to the main layer of the plot.
Returns
-------
plotnine.ggplot.ggplot
A plotnine object representing the visualization of the component function.
"""
style = utils.match_arg(style, ['effect', 'data'])
tags = term.split(':')
if style == 'data':
if data is None:
raise ValueError("The 'data' argument is required when style='data'. Please provide the pandas.DataFrame to use for plotting.")
data = data.copy()
terms = [term] + (tags if (len(tags) == 2 and main_effects) else [])
data['mid'] = (
estimator.r_predict(X=data, output_type='terms', terms=terms).sum(axis=1)
+ (estimator.intercept if intercept else 0)
)
if len(tags) == 1:
eff_df = estimator.main_effects(term)
if intercept:
eff_df['mid'] += estimator.intercept
p = p9.ggplot(data=eff_df, mapping=p9.aes(x=term, y='mid'))
enc = estimator._encoding_type(tag=term, order=1)
if style == 'effect':
if enc == 'linear':
p = p + p9.geom_line(**kwargs)
if theme is not None:
p = p + p9.aes(color='mid') + pt.scale_color_theme(theme)
elif enc == 'constant':
xval = eff_df[[f'{term}_min', f'{term}_max']].to_numpy().ravel('C')
yval = np.repeat(eff_df['mid'].to_numpy(), 2)
path_df = pd.DataFrame({term: xval, 'mid': yval})
p += p9.geom_path(data=path_df, **kwargs)
if theme is not None:
p = p + p9.aes(color='mid') + pt.scale_color_theme(theme)
else:
p += p9.geom_col(**kwargs)
if theme is not None:
p = p + p9.aes(fill='mid') + pt.scale_fill_theme(theme)
if style == 'data':
jit = jitter[0] if enc == 'factor' else 0
p += p9.geom_jitter(p9.aes(y = "mid"), data=data, width=jit, height=0, **kwargs)
if theme is not None:
p = p + p9.aes(color='mid') + pt.scale_color_theme(theme)
elif len(tags) == 2:
xtag, ytag = tags[0], tags[1]
try:
eff_df = estimator.interactions(term)
except KeyError as e:
try:
eff_df = estimator.interactions(f'{ytag}:{xtag}')
except KeyError:
raise e
if intercept:
eff_df['mid'] += estimator.intercept
if main_effects:
eff_df['mid'] += estimator.effect(term=xtag, x=eff_df) + estimator.effect(term=ytag, x=eff_df)
p = p9.ggplot(eff_df, p9.aes(x=xtag, y=ytag))
xenc = estimator._encoding_type(tag=xtag, order=2)
yenc = estimator._encoding_type(tag=ytag, order=2)
if style == 'effect':
xres, yres = (resolution, resolution) if isinstance(resolution, int) else (resolution, resolution)
if xenc == 'factor':
xval = eff_df[xtag].unique()
else:
xmin, xmax = eff_df[f'{xtag}_min'].min(), eff_df[f'{xtag}_max'].max()
xval = np.linspace(xmin, xmax, xres)
if yenc == 'factor':
yval = eff_df[ytag].unique()
else:
ymin, ymax = eff_df[f'{ytag}_min'].min(), eff_df[f'{ytag}_max'].max()
yval = np.linspace(ymin, ymax, yres)
grid_df = pd.DataFrame({
xtag: np.repeat(xval, len(yval)),
ytag: np.tile(yval, len(xval))
})
grid_df['mid'] = estimator.effect(term=term, x=grid_df)
if intercept:
grid_df['mid'] += estimator.intercept
if main_effects:
grid_df['mid'] += estimator.effect(term=xtag, x=grid_df) + estimator.effect(term=ytag, x=grid_df)
p += p9.geom_raster(p9.aes(x=xtag, y=ytag, fill='mid'), data=grid_df)
p += pt.scale_fill_theme(theme if theme is not None else 'midr')
if style == 'data':
xjit = jitter[0] if xenc == 'factor' else 0
yjit = jitter[1] if yenc == 'factor' else 0
p += p9.geom_jitter(
mapping=p9.aes(color='mid'), data=data, width=xjit, height=yjit, **kwargs
)
if theme is not None:
p += pt.scale_color_theme(theme)
else:
p += p9.scale_color_continuous()
return p
[docs]
def plot_importance(
importance: MIDImportance,
style: Literal['barplot', 'heatmap'] = 'barplot',
theme: str | pt.color_theme | None = None,
max_nterms: int | None = 30,
**kwargs
):
"""Visualize the importance scores of the component functions from a fitted MID model with plotnine.
This is a porting function for the R function `midr::ggmid.mid.importance()`.
Parameters
----------
importance : MIDImportance
A fitted :class:`MIDImportance` object containing the component importance scores.
style : {'barplot', 'heatmap'}, default 'barplot'
The plotting style.
'barplot' displays importance as horizontal bars, suitable for a large number of terms.
'heatmap' displays importance in a matrix format, suitable for visualizing main effects and two-way interactions simultaneously.
theme : str or pt.color_theme or None, default None
The color theme to use for the plot.
max_nterms : int or None, default 30
The maximum number of terms to display when `style='barplot'`.
Terms are sorted by importance before truncation. If None, all terms are displayed.
**kwargs : dict
Additional keyword arguments passed to the main layer of the plot.
Returns
-------
plotnine.ggplot.ggplot
A plotnine object representing the visualization of component importance.
"""
style = utils.match_arg(style, ['barplot', 'heatmap'])
imp_df = importance.importance.copy()
if style == 'barplot':
if max_nterms is not None:
imp_df = imp_df.head(max_nterms)
p = (
p9.ggplot(imp_df, p9.aes(x='term', y='importance'))
+ p9.geom_col(**kwargs)
+ p9.coord_flip()
+ p9.labs(x="")
)
if theme is not None:
theme = pt.color_theme(theme)
var_fill = 'order' if theme.type == 'qualitative' else 'importance'
p = p + p9.aes(fill=var_fill) + pt.scale_fill_theme(theme)
elif style == 'heatmap':
terms = imp_df['term'].str.split(':', expand=True)
if terms.shape[1] == 1:
terms.loc[:, 1] = None
terms[1] = terms[1].fillna(terms[0])
df1 = pd.DataFrame({
'x': terms[0], 'y':terms[1], 'importance': imp_df['importance']
})
df2 = pd.DataFrame({
'x': terms[1], 'y':terms[0], 'importance': imp_df['importance']
})
df = pd.concat([df1, df2]).drop_duplicates(ignore_index=True)
all_vars = pd.unique(np.concatenate([terms[0], terms[1]]))
df['x'] = pd.Categorical(df['x'], categories=all_vars)
df['y'] = pd.Categorical(df['y'], categories=all_vars)
p = (
p9.ggplot(df, p9.aes(x='x', y='y', fill='importance'))
+ p9.geom_tile(**kwargs)
+ p9.labs(x="", y="")
)
p += pt.scale_fill_theme(theme if theme is not None else 'grayscale')
return p
[docs]
def plot_breakdown(
breakdown: MIDBreakdown,
style: Literal['waterfall', 'barplot'] = 'waterfall',
theme: str | pt.color_theme | None = None,
max_nterms: int | None = 15,
catchall: str = 'others',
format: tuple[str, str] = ('%t=%v', '%t'),
**kwargs
):
"""Visualize the decomposition of a single prediction into contributions from each component term with plotnine.
This is a porting function for the R function `midr::ggmid.mid.breakdown()`.
Parameters
----------
breakdown : MIDBreakdown
A fitted :class:`MIDBreakdown` object containing the term contributions for a specific data point.
style : {'waterfall', 'barplot'}, default 'waterfall'
The plotting style.
'waterfall' displays contributions as a cascading plot, showing how each term adds to the final prediction, starting from the intercept.
'barplot' displays contributions as simple horizontal bars, relative to zero.
theme : str or pt.color_theme or None, default None
The color theme to use for the plot.
max_nterms : int or None, default 15
The maximum number of terms to display. Terms beyond this limit are
grouped into a single 'catchall' category. If None, all terms are displayed.
catchall : str, default 'others'
The label used for the grouped category when the number of terms exceeds `max_nterms`.
format : tuple[str, str], default ('%t=%v', '%t')
A tuple of two format strings for labeling terms on the y-axis.
The first string is for main effects (e.g., 'term=value'), and the second is for interaction terms (e.g., 'term').
%t is replaced by the term name, and %v is replaced by the predictor value.
**kwargs : dict
Additional keyword arguments passed to the main layer of the plot.
Returns
-------
plotnine.ggplot.ggplot
A plotnine object representing the breakdown visualization.
"""
style = utils.match_arg(style, ['waterfall', 'barplot'])
brk_df = breakdown.breakdown.copy()
if 'value' in brk_df.columns:
def _format_row(row):
_t = str(row['term'])
_v = str(row['value'])
fmt = format[1 if ':' in _t else 0]
return fmt.replace('%t', _t).replace('%v', _v)
brk_df['term'] = brk_df.apply(_format_row, axis=1)
if max_nterms is not None and max_nterms < len(brk_df):
resid = brk_df.iloc[max_nterms - 1:]['mid'].sum()
brk_df = brk_df.head(max_nterms - 1)
catchall_row = pd.DataFrame([{'term': catchall, 'mid': resid}])
brk_df = pd.concat([brk_df, catchall_row], ignore_index=True)
brk_df['term'] = pd.Categorical(
brk_df['term'], categories=brk_df['term'].iloc[::-1]
)
if style == 'waterfall':
intercept = breakdown.intercept
cs = np.cumsum(np.r_[intercept, brk_df['mid']])
brk_df['xmin'], brk_df['xmax'] = cs[:-1], cs[1:]
brk_df['ymin'], brk_df['ymax'] = brk_df['term'].cat.codes + 1 - 0.4, brk_df['term'].cat.codes + 1 + 0.4
brk_df['ymin2'] = (brk_df['ymin'] - 1).clip(lower=brk_df['ymin'].min())
p = (
p9.ggplot(brk_df, p9.aes(y='term'))
+ p9.geom_vline(xintercept=intercept, size=0.5)
+ p9.geom_rect(p9.aes(xmin='xmin', xmax='xmax', ymin='ymin', ymax='ymax'), **kwargs)
+ p9.geom_linerange(p9.aes(x='xmax', ymax='ymax', ymin='ymin2'), size=0.5)
+ p9.labs(x='yhat')
+ p9.scale_y_discrete(name="")
)
elif style == 'barplot':
p = (
p9.ggplot(brk_df, p9.aes(x='term', y='mid'))
+ p9.geom_col(**kwargs)
+ p9.geom_hline(yintercept=0, linetype='dashed', color='#808080')
+ p9.coord_flip()
+ p9.labs(x="")
)
if theme is not None:
theme = pt.color_theme(theme)
if theme.type == 'qualitative':
mid_sign = np.where(brk_df['mid'] > 0, '> 0', '< 0')
p = p + p9.aes(fill=mid_sign) + pt.scale_fill_theme(theme) + p9.labs(fill='mid')
else:
p = p + p9.aes(fill='mid') + pt.scale_fill_theme(theme)
return p
[docs]
def plot_conditional(
conditional: MIDConditional,
style: Literal['ice', 'centered'] = 'ice',
theme: str | pt.color_theme | None = None,
var_color: str | None = None,
dots: bool = True,
reference: int = 0,
**kwargs
):
"""Visualize Individual Conditional Expectation (ICE) plots or Centered ICE (c-ICE) plots with plotnine.
This is a porting function for the R function `midr::ggmid.mid.conditional()`.
Parameters
----------
conditional : MIDConditional
A fitted :class:`MIDConditional` object containing the ICE data.
style : {'ice', 'centered'}, default 'ice'
The plotting style.
'ice' plots raw predicted values against the predictor variable.
'centered' displays the **change in prediction** relative to a `reference` point,
by subtracting the prediction at the `reference` point for each individual observation.
theme : str or pt.color_theme or None, default None
The color theme to use for the line colors.
var_color : str or None, default None
The name of a column (from the original data) to map to the color aesthetic of the ICE lines. This helps visualize heterogeneity.
dots : bool, default True
If True, plots points for the observed (original) predictions for each sample.
reference : int, default 0
The 0-indexed sample point used as the reference prediction for centering when `style='centered'` is used.
**kwargs : dict
Additional keyword arguments passed to the main layer of the plot.
Returns
-------
plotnine.ggplot.ggplot
A plotnine object representing the conditional expectation visualization.
"""
style = utils.match_arg(style, ['ice', 'centered'])
variable = conditional.variable
obs_df = conditional.observed.copy()
con_df = conditional.conditional.copy()
if style == 'centered':
values = conditional.values
ref = values[min(len(values) - 1, max(0, reference))]
ref_df = con_df.loc[con_df[variable] == ref, ['.id', 'yhat']].rename(columns={'yhat': 'yref'})
obs_df = pd.merge(obs_df, ref_df, on='.id')
con_df = pd.merge(con_df, ref_df, on='.id')
obs_df['centered yhat'] = obs_df['yhat'] - obs_df['yref']
con_df['centered yhat'] = con_df['yhat'] - con_df['yref']
yvar = 'yhat' if style == 'ice' else 'centered yhat'
p = (
p9.ggplot(data=obs_df, mapping=p9.aes(x=variable, y=yvar))
+ p9.geom_line(p9.aes(group='.id'), data=con_df, **kwargs)
)
if dots:
p += p9.geom_point()
if var_color is not None:
p += p9.aes(color=var_color)
if theme is not None:
p += pt.scale_color_theme(theme)
return p