python · matplotlib · visualization · shap

Issues with Feature Importance Plots Generated by SHAP in Python


I have been using the SHAP library in Python to visualize feature importance for a machine learning model. Here's a snippet of the code I've been using:

import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

explainer = shap.TreeExplainer(trained_model_gbm)
shap_values = explainer.shap_values(x_test_selected)
shap_importance = np.abs(shap_values).mean(axis=0)  # mean |SHAP value| per feature
importance_df = pd.DataFrame({'features': selected_feature_names,
                              'importance': shap_importance})
importance_df.sort_values(by='importance', ascending=False, inplace=True)
print(importance_df)

shap_exp = shap.Explanation(values=shap_values, base_values=explainer.expected_value, data=x_test_selected,
                            feature_names=selected_feature_names)

shap.plots.beeswarm(shap_exp, max_display=len(selected_feature_names))

shap.plots.bar(shap_exp, max_display=len(selected_feature_names))

However, I've run into a problem: the generated plots crop the feature names, and the images are sometimes poorly scaled. How can I resolve this? Any insights or suggestions would be greatly appreciated!

Despite my attempts to adjust the figure size via the figsize parameter and to tweak the default figure settings before calling plt.show(), I haven't been able to get the desired result: the feature names are still cropped, and the overall image scaling remains inconsistent.
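
For context, my sizing attempt looked roughly like the snippet below; the figure dimensions and the show=False flag are just how I sketched it, not settings I'm attached to:

# Roughly what I tried: set a larger figure up front, then display the plot myself
plt.figure(figsize=(12, 8))
shap.plots.beeswarm(shap_exp, max_display=len(selected_feature_names), show=False)
plt.show()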


Solution

  • I've found a solution to my problem with cropped feature names and poorly scaled images in SHAP plots. Although the solution isn't perfect, as it still truncates very long feature names on the y-axis, it significantly improves the overall appearance and readability of the plots.

    Here's the solution:

    # Imports (trained_best_model, x_test_best, best_features_overall and save_dir
    # are defined earlier in my script)
    import os
    import shap
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    # SHAP Explanation
    explainer = shap.TreeExplainer(trained_best_model)
    shap_values = explainer.shap_values(x_test_best)
    shap_importance = np.abs(shap_values).mean(axis=0)
    importance_df = pd.DataFrame({'features': best_features_overall,
                                  'importance': shap_importance})
    importance_df.sort_values(by='importance', ascending=False, inplace=True)
    
    shap_exp = shap.Explanation(values=shap_values, base_values=explainer.expected_value, data=x_test_best,
                                feature_names=best_features_overall)
    
    # SHAP Bee Swarm Plot
    shap.plots.beeswarm(shap_exp, max_display=len(best_features_overall), show=False)
    plt.ylim(-0.5, len(best_features_overall) - 0.5)  # Set y-axis limits to avoid cutting off feature names
    plt.subplots_adjust(left=0.5, right=0.9)  # Adjust left and right margins of the plot
    plt.savefig(os.path.join(save_dir, 'shap_beeswarm.png'))
    plt.close()
    
    # SHAP Bar Plot
    shap.plots.bar(shap_exp, max_display=len(best_features_overall), show=False)
    plt.subplots_adjust(left=0.5, right=0.9)  # Adjust left and right margins of the plot
    plt.savefig(os.path.join(save_dir, 'shap_bar.png'))
    plt.close()
    
    # Append the SHAP feature importance table to a .txt file
    with open('importance_of_features.txt', 'a') as file:
        file.write("\n\nSHAP Explanation - Feature Importance\n")
        file.write(importance_df.to_string(index=False))
    

    The key adjustments are the y-axis limits on the beeswarm plot and the wider left margin (via subplots_adjust) on both plots, which keep the feature names from being cut off. Because the plots are drawn with show=False, they can be saved as PNG files before being closed, and the feature importance table is appended to a text file for later reference.
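
    Another option worth trying, which I haven't tested as thoroughly: let matplotlib trim the margins at save time with bbox_inches='tight' instead of hand-tuning subplots_adjust. A minimal sketch, assuming the same shap_exp, save_dir, and imports as above (the figure size heuristic is just an illustration):

    shap.plots.beeswarm(shap_exp, max_display=len(best_features_overall), show=False)
    plt.gcf().set_size_inches(12, 0.4 * len(best_features_overall) + 2)  # rough height heuristic; adjust to taste
    plt.savefig(os.path.join(save_dir, 'shap_beeswarm_tight.png'), bbox_inches='tight', dpi=150)
    plt.close()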