How can I extract the numerical values behind the SHAP summary plot so that the data can be viewed in a DataFrame?
Here is an MWE:
import numpy as np
import shap
from shap import Explainer, waterfall_plot, Explanation
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Generate noisy data
X, y = make_classification(n_samples=1000,
                           n_features=50,
                           n_informative=9,
                           n_redundant=0,
                           n_repeated=0,
                           n_classes=10,
                           n_clusters_per_class=1,
                           class_sep=9,
                           flip_y=0.2,
                           random_state=17)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
explainer = Explainer(model)
sv = explainer.shap_values(X_test)
shap.summary_plot(sv, X_test, plot_type="bar")
I tried
np.abs(shap_values.values).mean(axis=0)
but I get a shape of (50, 10). How do I get just the aggregated value for each feature, so that I can sort the features by importance?
You've done this:
import numpy as np
import pandas as pd
from shap import Explainer, summary_plot
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Generate noisy data
X, y = make_classification(
    n_samples=1000,
    n_features=50,
    n_informative=9,
    n_redundant=0,
    n_repeated=0,
    n_classes=10,
    n_clusters_per_class=1,
    class_sep=9,
    flip_y=0.2,
    random_state=17,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
explainer = Explainer(model)
sv = explainer.shap_values(X_test)
summary_plot(sv, X_test, plot_type="bar")
Note that features 3, 29, 34, and so on are at the top.
If you do:
np.abs(sv).shape
(10, 250, 50)
you'll find that the array holds 10 classes, 250 data points and 50 features.
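The (10, 250, 50) shape is the list-per-class layout that `explainer.shap_values` returned for multiclass models in the shap version used here. As far as I know, newer shap releases return a single array with the class axis last, i.e. (250, 50, 10), so `np.abs(sv).shape` may differ on your machine. The sketch below normalizes both layouts to classes-first so the rest of this answer works either way. `to_class_first` is a hypothetical helper, not part of shap, and the data is random, not real SHAP output.

```python
import numpy as np


def to_class_first(shap_values):
    """Return multiclass SHAP values as one array of shape
    (n_classes, n_samples, n_features).

    Hypothetical helper. Assumes the input is either a list with one
    (n_samples, n_features) array per class, or a single 3-D array of
    shape (n_samples, n_features, n_classes) with the class axis last.
    """
    if isinstance(shap_values, list):
        # Stacking along a new first axis puts the classes first.
        return np.stack(shap_values, axis=0)
    arr = np.asarray(shap_values)
    if arr.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {arr.shape}")
    # Move the class axis from last to first.
    return np.moveaxis(arr, -1, 0)


# Synthetic stand-ins for the two layouts (random numbers).
rng = np.random.default_rng(0)
list_layout = [rng.normal(size=(250, 50)) for _ in range(10)]
array_layout = np.stack(list_layout, axis=-1)  # shape (250, 50, 10)

print(to_class_first(list_layout).shape)   # (10, 250, 50)
print(to_class_first(array_layout).shape)  # (10, 250, 50)
```

With `sv = to_class_first(sv)` in place, `np.abs(sv).mean(1)` below gives the same (10, 50) result regardless of which layout your shap version returns.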
If you average the absolute values over the data points (axis 1), you get everything you need:
aggs = np.abs(sv).mean(1)
aggs.shape
(10, 50)
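To see what `np.abs(sv).mean(1)` computes, here is a toy array with the same (classes, data points, features) layout but tiny, made-up dimensions and values. Each entry of the result is the mean absolute SHAP value of one feature for one class; taking the absolute value first matters because positive and negative contributions would otherwise cancel.

```python
import numpy as np

# Toy stand-in for sv: (n_classes=2, n_samples=3, n_features=2), made-up values.
sv = np.array([
    [[1.0, -2.0], [-3.0, 0.0], [2.0, 2.0]],   # class 0
    [[0.5, 0.5], [-0.5, 1.5], [2.0, -1.0]],   # class 1
])

# Mean |SHAP| per (class, feature): average over the samples axis.
aggs = np.abs(sv).mean(axis=1)
print(aggs.shape)  # (2, 2)

# Without abs, the signed values of class 0, feature 1 (-2, 0, 2) cancel out.
signed = sv.mean(axis=1)
print(signed[0, 1])  # 0.0
```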
You can plot it:
sv_df = pd.DataFrame(aggs.T)  # one row per feature, one column per class
sv_df.plot(kind="barh", stacked=True)
And if it still doesn't look familiar, sort the features by their total across classes and keep the top 10:
sv_df.loc[sv_df.sum(1).sort_values(ascending=True).index[-10:]].plot(
kind="barh", stacked=True
)
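The question asks for one number per feature. Each feature's stacked bar has a total length equal to the sum of its per-class mean |SHAP| values, so summing `sv_df` across columns gives that number, and sorting it gives the same ranking as the `sort_values` call above. The sketch uses a made-up 2-class, 3-feature array in place of the real `aggs`.

```python
import numpy as np
import pandas as pd

# Made-up aggregated values with layout (n_classes=2, n_features=3).
aggs = np.array([
    [0.25, 0.5, 0.125],   # class 0
    [0.25, 0.125, 0.75],  # class 1
])
sv_df = pd.DataFrame(aggs.T)  # one row per feature, one column per class

# Total importance per feature = length of its stacked bar, largest first.
importance = sv_df.sum(axis=1).sort_values(ascending=False)
print(importance.index.tolist())  # [2, 1, 0]
```

On the real data, `importance.head(10)` would list the ten most important features together with their aggregated values.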
Conclusion:
sv_df holds the aggregated SHAP values, as in the summary plot, with one row per feature and one column per class.
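The (50, 10) array from the question's attempt is consistent with SHAP values laid out as (n_samples, n_features, n_classes), which, as far as I know, is how `Explanation.values` is shaped for a multiclass model (an assumption about where `shap_values.values` came from). Averaging over axis 0 then gives features × classes, which is already the orientation of `sv_df`, and summing over the class axis gives one value per feature. The sketch checks this against the classes-first route on random stand-in data.

```python
import numpy as np
import pandas as pd

# Random stand-in for Explanation.values: (n_samples, n_features, n_classes).
rng = np.random.default_rng(1)
values = rng.normal(size=(250, 50, 10))

# The question's attempt: average |SHAP| over samples -> (n_features, n_classes).
per_class = np.abs(values).mean(axis=0)
print(per_class.shape)  # (50, 10)

# Same numbers via the classes-first route used above: np.abs(sv).mean(1).T
sv = np.moveaxis(values, -1, 0)  # (10, 250, 50)
print(np.allclose(per_class, np.abs(sv).mean(1).T))  # True

# One value per feature: sum over classes, then sort descending.
feature_importance = (
    pd.Series(per_class.sum(axis=1), name="mean_abs_shap")
    .sort_values(ascending=False)
)
print(feature_importance.head(10))
```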
Does it help?