Problem Summary:
Here is what I have done.
I have a binary-classification and model-interpretability pipeline, structured as follows:
import matplotlib.pyplot as plt
import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier

# Note: min_impurity_split and max_features='auto' only exist in older
# scikit-learn releases (removed in 1.0 and 1.3, respectively).
forest = RandomForestClassifier(bootstrap=True, class_weight='balanced', criterion='gini',
                                max_depth=100, max_features='auto', max_leaf_nodes=10,
                                min_impurity_decrease=0.0, min_impurity_split=None,
                                min_samples_leaf=1, min_samples_split=2,
                                min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=2,
                                oob_score=False, random_state=0, verbose=1, warm_start=False)
forest.fit(x_train, y_train)

explainer = shap.TreeExplainer(forest)
shap_values = explainer.shap_values(x_train)

figure = plt.figure()
shap.summary_plot(shap_values, features=x_train, feature_names=x_train.columns, plot_type="bar")
The code works perfectly.
I am trying to extract the data behind the bar graph, "figure". As the "figure = plt.figure()" line shows, it is a Matplotlib plot, not the default JS plot the package produces.
Once I have the data, I want to present it in a DataFrame with two columns, "Feature" and "Shapley_Value". What do I do?
The documentation states: "where the global importance of each feature is taken to be the mean absolute value for that feature over all the given samples."
Which means the bar heights are:

bars_value = np.abs(shap_values).mean(axis=0)

(One caveat: for a binary classifier, older shap versions return shap_values as a list of two per-class arrays that are sign mirrors of each other, so select one class first, e.g. np.abs(shap_values[1]).mean(axis=0).)
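Putting this together, the DataFrame can be built directly from the mean absolute SHAP values, with no need to scrape the Matplotlib figure. Below is a minimal sketch using a synthetic `shap_vals` array and made-up `feature_names` as stand-ins; in the real pipeline these would be one class's array from `explainer.shap_values(x_train)` and `x_train.columns`:

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for one class's SHAP values: shape (n_samples, n_features).
rng = np.random.default_rng(0)
shap_vals = rng.normal(size=(5, 3))
feature_names = ["f0", "f1", "f2"]  # stand-in for x_train.columns

# Global importance per the docs: mean absolute SHAP value over all samples.
bars_value = np.abs(shap_vals).mean(axis=0)

# Two-column DataFrame, sorted largest-first to match the bar plot's ordering.
importance_df = (
    pd.DataFrame({"Feature": feature_names, "Shapley_Value": bars_value})
    .sort_values("Shapley_Value", ascending=False)
    .reset_index(drop=True)
)
print(importance_df)
```

Sorting descending reproduces the order in which `summary_plot` draws the bars, so the first row is the globally most important feature.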