I'd like to add the `shap.plots.bar` (https://github.com/slundberg/shap) figure to a subplot. Something like this:
```python
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 20))
for (X, y) in [(x1, y1), (x2, y2)]:
    model = xgboost.XGBRegressor().fit(X, y)
    explainer = shap.Explainer(model, check_additivity=False)
    shap_values = explainer(X, check_additivity=False)
    shap.plots.bar(shap_values, max_display=6, show=False)  # ax=ax ??
plt.show()
```
However, `ax` is not a parameter of `shap.plots.bar`, unlike some other plotting functions such as `shap.dependence_plot(..., ax=ax[0, 0], show=False)`. Is there a way to put several bar plots into subplots of one figure?
Looking at the source code, `shap.plots.bar` does not create its own figure; it draws on the current axes. So you can create a figure and then make the desired axis the current axes using `plt.sca`.
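A minimal illustration of what `plt.sca` does, using plain matplotlib so it runs without shap or xgboost. Here `plt.barh` stands in for any function that draws through pyplot's current axes; the `Agg` backend is an assumption so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(2, 1)

# pyplot functions such as plt.barh draw on whichever axes is "current",
# so select the target subplot explicitly before each call.
plt.sca(ax1)
plt.barh(["a", "b", "c"], [3, 2, 1])  # lands on ax1

plt.sca(ax2)
plt.barh(["x", "y"], [5, 4])          # lands on ax2

print(len(ax1.patches), len(ax2.patches))  # 3 2
```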
Here is how you'd do it using the bar plot sample code from the documentation.
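```python
import xgboost
import shap
import matplotlib.pyplot as plt

# Sample data and model from the shap bar-plot documentation
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)                             # make the second subplot the current axes
shap.plots.bar(shap_values, show=False)  # show=False: don't let shap call plt.show() itself
fig.tight_layout()
plt.show()
```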
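The question's original goal, one bar plot per model side by side, follows the same pattern inside a loop. To keep the sketch runnable without shap or xgboost, `pyplot_style_bar` and the random `results` below are hypothetical stand-ins for `shap.plots.bar(shap_values, max_display=6, show=False)` and the per-model SHAP values; for real use, swap those back in.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def pyplot_style_bar(values, names, max_display=6):
    """Hypothetical stand-in for shap.plots.bar(..., show=False):
    draws a horizontal bar chart through pyplot's current axes."""
    magnitudes = np.abs(values)
    order = np.argsort(magnitudes)[::-1][:max_display]  # largest first
    # Reverse so the largest bar ends up at the top of the chart.
    plt.barh([names[i] for i in order][::-1], magnitudes[order][::-1])
    plt.xlabel("mean(|SHAP value|)")


# Hypothetical per-model values standing in for two Explanation objects.
names = [f"f{i}" for i in range(8)]
rng = np.random.default_rng(0)
results = [rng.normal(size=8), rng.normal(size=8)]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
for ax, values in zip(axes, results):
    plt.sca(ax)  # make this subplot current before the plotting call
    pyplot_style_bar(values, names, max_display=6)
fig.tight_layout()
```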
If you really want to pass an `ax` argument, you'll have to edit the source code to add that option.
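An alternative that avoids editing shap itself is a small wrapper that accepts `ax`, makes it current for the duration of the call, and then restores whatever axes was current before. `with_ax` is a hypothetical helper, not part of shap or matplotlib, and the demo uses a plain pyplot function in place of `shap.plots.bar` so it runs on its own.

```python
import functools

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt


def with_ax(plot_fn):
    """Hypothetical helper: add an ``ax`` keyword to a function that draws on
    pyplot's current axes, by making ``ax`` current during the call."""
    @functools.wraps(plot_fn)
    def wrapper(*args, ax=None, **kwargs):
        if ax is None:
            return plot_fn(*args, **kwargs)
        previous = plt.gca()
        plt.sca(ax)
        try:
            return plot_fn(*args, **kwargs)
        finally:
            plt.sca(previous)  # restore the previously current axes
    return wrapper


# Demo: a plain pyplot function standing in for shap.plots.bar.
def simple_bar(labels, widths):
    plt.barh(labels, widths)


fig, (ax1, ax2) = plt.subplots(2, 1)
plt.sca(ax2)                           # ax2 is current before the call...
bar_on_ax = with_ax(simple_bar)
bar_on_ax(["a", "b"], [1, 2], ax=ax1)  # ...but the bars land on ax1
```

With shap installed, the same helper would presumably be used as `with_ax(shap.plots.bar)(shap_values, ax=ax1, show=False)`; that usage is untested and relies on `shap.plots.bar` drawing on the current axes.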