Tags: python-3.x, resolution, shap

How to increase the resolution of a saved SHAP plot


In the code below, when I save the SHAP summary plot, the saved image has very low resolution in both PDF and PNG format. Is it possible to increase the resolution of the saved image?

Here is my code (note: the random-forest grid search takes about 10 minutes to finish):

import pickle
from pathlib import Path

import joblib
import numpy as np
import seaborn as sns
import shap
from matplotlib import pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV


# Resolution (dots per inch) used when writing the plot to disk.
# shap may rasterize the scatter layer of the summary plot when there are many
# samples, so this also affects the PDF output, not only the PNG.
SAVE_DPI = 300
# Index of the class whose SHAP values are plotted.
CLASS_INDEX = 6
OUTPUT_DIR = Path("PDF_plots")

# Generate noisy data: 10 classes, 240 features of which only 9 are informative.
X_train, y_train = make_classification(
    n_samples=2000,
    n_features=240,
    n_informative=9,
    n_redundant=0,
    n_repeated=0,
    n_classes=10,
    n_clusters_per_class=1,
    class_sep=9,
    flip_y=0.2,
    # weights=[0.5, 0.5],
    random_state=17,
)

model = RandomForestClassifier()

parameter_space = {
    'n_estimators': [10, 50, 100],
    'criterion': ['gini', 'entropy'],
    # max_depth must be an integer: np.linspace(10, 50, 11) yields floats,
    # which recent scikit-learn versions reject during parameter validation.
    # range(10, 51, 4) produces the same depths (10, 14, ..., 50) as ints.
    'max_depth': list(range(10, 51, 4)),
}

clf = GridSearchCV(model, parameter_space, cv=5, scoring="accuracy", verbose=True)
clf.fit(X_train, y_train)
print(f'Best Parameters: {clf.best_params_}')

shap_values = shap.TreeExplainer(clf.best_estimator_).shap_values(X_train)

# Older shap releases return a list with one (n_samples, n_features) array per
# class; newer releases return a single array with the class on the last axis
# (NOTE(review): confirm against the installed shap version).
if isinstance(shap_values, list):
    class_shap_values = shap_values[CLASS_INDEX]
else:
    class_shap_values = shap_values[..., CLASS_INDEX]

# show=False is essential: by default summary_plot calls plt.show(), after
# which the figure may be closed/cleared and savefig writes an unusable image.
shap.summary_plot(class_shap_values, X_train, show=False)
fig = plt.gcf()  # summary_plot draws into the current figure

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(OUTPUT_DIR / "Test6_plot1.pdf", dpi=SAVE_DPI, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / "Test6.png", dpi=SAVE_DPI, bbox_inches='tight')
plt.close(fig)

Solution

  • The resolution is controlled by matplotlib's `savefig` (see the matplotlib `savefig` documentation).

    `plt.savefig` has a parameter named `dpi` (dots per inch):

    plt.savefig(img, dpi=300)
    

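    Alternatively, set the resolution when creating the figure, e.g. `plt.figure(dpi=1200)`, before plotting. With SHAP, also call `shap.summary_plot(..., show=False)`; otherwise the plot is shown (and the figure may be closed) before `savefig` runs, and the saved file will not contain the plot.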