python pandas matplotlib

Extracting Data from a Matplotlib Bar Graph Produced by the "SHAP" Module


Problem Summary:

  1. I am trying to extract data from a Matplotlib bar graph. I know this entails using Matplotlib's gca function, but the examples I have tried so far have failed.
  2. The figure is produced automatically by the Python machine learning model interpretability module "SHAP". For more details, see: https://github.com/slundberg/shap
  3. After I have extracted the data, I would like to present it as a pandas DataFrame.

Here is what I have done.

I have a binary classification and model interpretability pipeline which is structured as follows:

import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import RandomForestClassifier

# x_train is a pandas DataFrame of features; y_train holds the binary labels.
forest = RandomForestClassifier(bootstrap=True, class_weight='balanced', criterion='gini', max_depth=100,
                                max_features='auto', max_leaf_nodes=10, min_impurity_decrease=0.0,
                                min_impurity_split=None, min_samples_leaf=1, min_samples_split=2,
                                min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=2,
                                oob_score=False, random_state=0, verbose=1, warm_start=False)

# Fit the model, then compute SHAP values with the tree explainer.
forest.fit(x_train, y_train)
explainer = shap.TreeExplainer(forest)
shap_values = explainer.shap_values(x_train)

# Draw the summary bar plot on a Matplotlib figure.
figure = plt.figure()
shap.summary_plot(shap_values, features=x_train, feature_names=x_train.columns, plot_type="bar")

[Figure: the resulting SHAP summary bar plot]

The code works perfectly.

I am trying to extract the data from the bar graph, "figure". As the line "figure = plt.figure()" shows, it is a Matplotlib graph, not the default JS graph produced by the package.

Once I have the data, I aim to present it in a DataFrame with two columns: "Feature" and "Shapely_Value". How do I do this?


Solution

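  • The documentation describes the bar plot as one "where the global importance of each feature is taken to be the mean absolute value for that feature over all the given samples."

    This means the bar lengths can be computed directly from shap_values, without reading them back from the figure: bars_value = np.abs(shap_values).mean(axis=0)

    This assumes shap_values is a single array of shape (n_samples, n_features). Depending on the shap version, a classifier may instead yield one such array per class; in that case, apply the formula to each class's array separately.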
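
To get from that formula to the two-column DataFrame asked for in the question, something along these lines should work. This is only a sketch: the helper name shap_importance_frame and the toy array are made up for illustration, and it assumes the SHAP values form one 2-D array of shape (n_samples, n_features). With the real pipeline you would pass shap_values (or a single class's array) and x_train.columns instead of the toy inputs.

```python
import numpy as np
import pandas as pd


def shap_importance_frame(shap_values, feature_names):
    """Return the mean absolute SHAP value per feature, largest first.

    Hypothetical helper (not part of shap). Expects a 2-D array-like of
    shape (n_samples, n_features).
    """
    values = np.asarray(shap_values)
    if values.ndim != 2:
        raise ValueError("expected a 2-D (n_samples, n_features) array")
    # Same quantity the bar plot shows: mean |SHAP| over all samples.
    bars_value = np.abs(values).mean(axis=0)
    frame = pd.DataFrame(
        {"Feature": list(feature_names), "Shapely_Value": bars_value}
    )
    # Sort so the most important feature comes first.
    return frame.sort_values("Shapely_Value", ascending=False).reset_index(drop=True)


# Toy stand-in for explainer.shap_values(x_train): 3 samples, 2 features.
toy_shap_values = np.array([[0.1, -0.4], [-0.3, 0.2], [0.2, 0.6]])
importance_df = shap_importance_frame(toy_shap_values, ["age", "income"])
print(importance_df)
```

Computing the values from shap_values directly avoids depending on how the plot happens to lay out its bars, which is why this sketch does not touch the Matplotlib figure at all.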