# treesデータセットを読み込み、ランダムフォレストモデルを作成する。
library(ranger)
data(trees, package = "datasets")
model_rf <- ranger(Volume~., trees)DALEX
パッケージの概要
DALEX パッケージは、予測モデルを解釈するために開発された可視化手法を統一的な記法で実行するためのパッケージです。DALEX パッケージでは、予測モデルに explain() 関数を適用することで、explainer オブジェクトが作成されます。この explainer オブジェクトに対して、DALEX パッケージのさまざまな関数を適用することで、PDP、ICE、SHAP などのプロットを簡単に作成することができます。
それぞれの手法の詳細については、たとえば、解釈可能な機械学習に関するウェブ書籍 Interpretable Machine Learning(邦訳)などをご参照ください。
ここでは、例として、trees データセットの Volume を Girth と Height から予測するモデルを作成し、そのモデルに対して解釈手法を適用してみましょう。
モデル解釈のためのオブジェクトを作成する
explain() 関数は、さまざまなパッケージのもとで作成された予測モデルを、DALEX パッケージの他の関数に対応するように加工するための関数です。加工後の予測モデルは explainer オブジェクトと呼ばれます。
library(ggplot2)
library(dplyr)
library(DALEX)
explainer <- explain(model_rf,
data = select(trees, -Volume),
y = trees$Volume,
quietly = TRUE,
verbose = FALSE)個別の予測における特徴量と予測値の関係を解釈する
explainer オブジェクトに predict_profile() 関数を適用すると、ICE(Individual Conditional Expectation)プロットを作図することができます。ICE プロットは、注目している特徴量の値だけが違っていた場合に予測値がどのように変化するかを、個々の予測ごとに可視化するものです。
ice <- explainer %>% predict_profile(new_observation = trees)
plot(ice)
モデルにおける特徴量と予測値の関係を解釈する
explainer オブジェクトにmodel_profile() 関数を適用すると、PD(Partial Dependence)プロットを作図することができます。PD プロットは、データ全体の ICE プロットを平均したものにほかならず、注目している特徴量の値が変化したときに予測値が平均的にどのように変化するかを表していると解釈できます。
pdp <- explainer %>% model_profile()
plot(pdp, geom = 'profiles') + theme_gray()
また、model_profile() 関数や plot() 関数用のメソッドの引数 geom を調整することで、ALE(Accumulated Local Effects)プロットを作成したり、実際のデータ点を表示したりするほか、さまざまな変更を加えることが可能です。
ale <- explainer %>% model_profile(type = "accumulated")
plot(ale, geom = 'points') + theme_bw()
個別の予測における特徴量の寄与を解釈する
explainer オブジェクトに predict_parts() 関数を適用すると、SHAP(SHapley Additive exPlanation)プロットを作図することができます。SHAP は、個別の予測値と平均的な予測値との差を、ゲーム理論的手法によって特徴量ごとの寄与に分解したものです。ここでは、5 番目のインスタンスに対する予測への特徴量ごとの寄与を表示してみます。
shap <- explainer %>% predict_parts(trees[5,], type = 'shap')
plot(shap) + theme_light() + theme(legend.position = 'null')
SHAPをウォーターフォール図として描く
predict_parts() 関数が出力したオブジェクト(predict_parts オブジェクト)を shapviz パ ッケージの shapviz() 関数で shapviz オブジェクトに変換することで、ウォーターフォール図を描くことも可能です。
library(shapviz)
sv_waterfall(shapviz(shap)) + theme_light()
モデルにおける特徴量の重要度を解釈する
explainer オブジェクトに model_parts() 関数を適用すると、PFI(Permutation Feature Importance)プロットを作図することができます。PFI は、「データの中で特定の特徴量だけをランダムに並び替えたときに、予測精度がどの程度低下するか」をその特徴量の重要度として解釈するものです。
pfi <- explainer %>% model_parts()
plot(pfi) + theme_bw() + theme(legend.position = 'none')
参考資料
DALEX パッケージには、ここで紹介した手法以外にもさまざまな便利な関数が用意されています。以下のウェブ書籍には、それらの手法の説明だけでなく、R と Python の具体的なコード例も紹介されており、大変有用です。
Przemyslaw Biecek and Tomasz Burzykowski, Explanatory Model Analysis Explore, Explain, and Examine Predictive Models. With examples in R and Python. https://ema.drwhy.ai/