I am using XGBoost with SHAP to analyze feature importance in a multiclass classification problem and need help plotting the SHAP summary plots for all classes at once. Currently, I can only generate plots one class at a time.
SHAP version: 0.45.0
Python version: 3.10.12
Here is my code:
import xgboost as xgb
import shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
# Generate synthetic data
X, y = make_classification(n_samples=500, n_features=20, n_informative=4, n_classes=6, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
# Train a XGBoost model for multiclass classification
model = xgb.XGBClassifier(objective="multi:softprob", random_state=42)
model.fit(X_train, y_train)
I then tried to plot the SHAP values:
# Create a SHAP TreeExplainer
explainer = shap.TreeExplainer(model)
# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)
# Attempt to plot summary for all classes
shap.summary_plot(shap_values, X_test, plot_type="bar")
I got this interaction plot instead:
I remedied the problem with help from this post:
shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar")
which gives a normal bar plot for class 0:
I can then do the same with classes 1, 2, 3, etc.
The question is, how can you make a summary plot for all the classes? I.e., a single plot showing the contribution of a feature to each class?
The issue is that explainer.shap_values(X_test) returns a 3D NumPy array of shape (rows, features, classes), whereas summary_plot, to draw the multiclass bar plot, requires shap_values to be a list of (rows, features) arrays whose length equals the number of classes.
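To make the shape difference concrete, here is a minimal sketch that uses a random NumPy array as a stand-in for the real SHAP output. The name fake_shap_values and the sizes (125 test rows, 20 features, 6 classes, taken from the question's setup) are assumptions for illustration only; the numbers themselves are meaningless.

```python
import numpy as np

# Stand-in for explainer.shap_values(X_test): random numbers laid out as
# (rows, features, classes). This is NOT real SHAP output.
rng = np.random.default_rng(0)
fake_shap_values = rng.normal(size=(125, 20, 6))

# Move the class axis to the front, then split into one 2D array per class.
per_class = list(np.moveaxis(fake_shap_values, -1, 0))

print(len(per_class))       # number of classes
print(per_class[0].shape)   # (rows, features) for class 0
```

Each element of per_class is the same data as fake_shap_values[:, :, i], just handed over as a separate 2D array.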
For my own purposes, I used the following function which converts your shap_values into the format that you need:
def shap_values_to_list(shap_values, model):
    # Split the (rows, features, classes) array into one 2D array per class
    shap_as_list = []
    for i in range(len(model.classes_)):
        shap_as_list.append(shap_values[:, :, i])
    return shap_as_list
Then you can do:
shap_as_list = shap_values_to_list(shap_values, model)
shap.summary_plot(shap_as_list, X_test, plot_type="bar")
You can also pass feature_names and class_names to summary_plot if you need them. With my own example, I went from the same kind of interaction plot that you got to the following:
[Image: example of shap.summary_plot output using shap_values converted to a list of per-class arrays]
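If you also want the numbers behind the bars, my understanding is that the multiclass bar plot is based on the mean absolute SHAP value per feature and per class, so you can compute that table directly. This is a sketch under that assumption, not the library's own code; mean_abs_shap_per_class is a hypothetical helper name and the tiny demo array is hand-made data, not real SHAP output.

```python
import numpy as np

def mean_abs_shap_per_class(shap_values_3d):
    """Return an array of shape (features, classes): mean |SHAP| over rows."""
    values = np.asarray(shap_values_3d)
    if values.ndim != 3:
        raise ValueError("expected shape (rows, features, classes)")
    return np.abs(values).mean(axis=0)

# Hand-made example: 2 rows, 2 features, 2 classes (illustrative values only).
demo = np.array([
    [[1.0, -2.0], [0.0, 4.0]],
    [[-3.0, 2.0], [2.0, 0.0]],
])
importance = mean_abs_shap_per_class(demo)
print(importance)

# Rank features by their total importance across classes, largest first.
ranking = np.argsort(importance.sum(axis=1))[::-1]
print(ranking)
```

With the real data you would call mean_abs_shap_per_class(shap_values) on the 3D array directly, without converting it to a list first.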