使用 SHAP 使機器學習模型變的可解釋

大家好,我是小寒
今天給大家分享機器學習中的一個關鍵概念,SHAP
SHAP 是一種用于解釋機器學習模型輸出的統一框架。它基于博弈論中的 Shapley 值,用來量化每個特征對模型預測結果的貢獻度。幫助我們理解模型為什么做出這樣的預測。
簡單來說,SHAP 計算每個特征在不同特征組合中對預測的邊際貢獻,從而為復雜模型提供透明、可解釋的輸出。

SHAP 的核心原理
SHAP 的理論基礎來源于合作博弈論中的 Shapley 值。
在合作博弈論中,Shapley 值用于公平地分配合作者在合作中所產生的總收益。
SHAP 將這一思想巧妙地應用到機器學習模型的特征貢獻分配上。
- 參與者:對應機器學習中的特征。
- 合作收益:對應模型的預測結果。
- 目標:計算每個特征對預測結果的邊際貢獻,即該特征加入模型后帶來的增益。
在 SHAP 中,模型預測值被視為總收益,而每個特征則被視為一個參與者。SHAP 值就是計算每個特征在所有可能的特征組合中對預測的平均邊際貢獻,從而解釋了為什么模型會做出某個特定的預測。
為什么要用 SHAP
在現實世界的應用中,很多機器學習模型,尤其是復雜的模型(如深度學習、集成樹模型),往往被稱為“黑箱”模型。
這意味著我們知道它們能做出預測,但很難理解它們為什么會做出某個特定的預測。
這種缺乏透明度會帶來許多問題:
- 信任問題:用戶和利益相關者可能不信任模型的決策,尤其是在高風險領域(如醫療診斷、金融信貸)。
- 調試與改進:當模型表現不佳時,我們難以定位問題出在哪里,是數據問題還是模型本身的問題?哪個特征導致了錯誤的預測?
- 公平性與偏見:模型是否基于不公平或有偏見的特征做出了決策?SHAP 可以幫助我們識別潛在的偏見。
- 合規性:在某些行業,解釋模型決策是法規要求。
SHAP 的出現,為解決這些問題提供了強大的工具,它能夠提供:
- 局部可解釋性
解釋單個預測是如何形成的。例如,為什么一個特定的客戶被預測為“流失”?哪些特征導致了這個結果? - 全局可解釋性
理解整個模型的行為。哪些特征對模型的整體預測貢獻最大?特征之間是否存在交互作用?
數學公式

可加性解釋模型
SHAP 提出了一種“可加性解釋模型”的概念,即任何復雜的模型預測都可以被解釋為基線值與特征貢獻的加和

計算 SHAP 值的近似方法
由于直接計算 Shapley 值涉及到遍歷所有可能的特征組合,計算復雜度為 O(2n),這在特征數量較多時會面臨組合爆炸的問題。
因此,SHAP 提出了多種近似算法來提高計算效率
- Kernel SHAP
這是一種模型無關的 SHAP 算法,通過訓練一個加權線性回歸模型來近似 Shapley 值。它使用一個特殊的核函數來給不同的特征組合賦權重,使得與目標預測更相似的組合具有更高的權重。 - Tree SHAP
專為樹模型(如決策樹、隨機森林、XGBoost、LightGBM)設計的優化算法。
Tree SHAP 利用樹模型的結構特性,可以比 Kernel SHAP 更高效、更精確地計算 Shapley 值。 - Deep SHAP
針對深度學習模型設計的算法。它通過反向傳播 Shapley 值來解釋神經網絡的輸出。
案例分享
下面是一個 Python 示例代碼,展示如何用 SHAP 庫來解釋一個簡單的模型預測。
import xgboost
import shap
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
# 1. 加載加州房價數據集
housing = fetch_california_housing()
X, y = housing.data, housing.target
feature_names = housing.feature_names
# 2. 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. 訓練XGBoost回歸模型
model = xgboost.XGBRegressor(objective='reg:squarederror', random_state=42)
model.fit(X_train, y_train)
# 4. 計算SHAP值
explainer = shap.Explainer(model)
shap_values = explainer(X_test)
# 5. summary plot(點圖)
shap.summary_plot(shap_values, X_test, feature_names=feature_names)
# 6. summary plot(bar plot)
shap.summary_plot(shap_values, X_test, feature_names=feature_names, plot_type="bar")
# 7. 選擇一個樣本,waterfall plot(局部解釋)
sample_idx = 0
shap.plots.waterfall(shap_values[sample_idx])






























