I have built a dashboard in Streamlit where you can select a client_ID and display SHAP plots (waterfall and force plots) that interpret the credit-default prediction for that client.
I also want to display a SHAP summary plot for the whole training dataset. The latter does not change every time a new prediction is made, and it takes a long time to plot, so I want to cache it. I guess the best approach would be to use st.cache, but I have not been able to make it work.
Below is the code I have tried, unsuccessfully, in main.py. I first define the function whose output (fig) I want to cache, then pass that output to st.pyplot. It works without the st.cache decorator, but as soon as I add it and rerun the app, summary_plot_all runs indefinitely.
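IN:
```python
import matplotlib.pyplot as plt
import shap
import streamlit as st

# shapvs and prep_train are computed earlier in main.py
@st.cache
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig

st.pyplot(summary_plot_all())
```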
OUT (displayed in the Streamlit app):
Running summary_plot_all().
Does anyone know what is wrong, or know a better way to cache a plot in Streamlit?
Package versions: streamlit==0.84.1, matplotlib==3.4.2, shap==0.39.0
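Try telling st.cache to skip hashing the returned matplotlib Figure by mapping its type to a constant in hash_funcs:
```python
import matplotlib
import matplotlib.pyplot as plt
import shap
import streamlit as st

# Hash every Figure to a constant so st.cache does not traverse the figure object
@st.cache(hash_funcs={matplotlib.figure.Figure: lambda _: None})
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig
```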
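Why the hash_funcs entry helps: the legacy st.cache hashes not only the function's inputs but also its return value, so it can warn if a cached result is later mutated, and a matplotlib Figure is a large object graph that the hasher can spend a very long time walking. Mapping the Figure type to a constant skips that walk. The toy, stdlib-only decorator below is a hypothetical sketch of the idea, not Streamlit's implementation (toy_cache, default_hash and FigureLike are made-up names): it re-hashes the cached result on every hit to detect mutation, and a per-type override replaces the default hash.

```python
import hashlib
import pickle
from typing import Any, Callable, Dict, Optional


def default_hash(value: Any) -> str:
    # Stand-in for a deep content hash; may be slow or fail on complex objects.
    return hashlib.sha256(pickle.dumps(value)).hexdigest()


def toy_cache(hash_funcs: Optional[Dict[type, Callable[[Any], Any]]] = None):
    """Hypothetical memoizer: caches results by positional args and, like the
    legacy st.cache, re-hashes the cached result on each hit to detect mutation."""
    overrides = hash_funcs or {}

    def hash_value(value: Any) -> Any:
        # A per-type override (exact type match here) replaces the default hash.
        override = overrides.get(type(value))
        return override(value) if override is not None else default_hash(value)

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        store: Dict[tuple, tuple] = {}

        def wrapper(*args: Any) -> Any:
            if args in store:
                value, stored_hash = store[args]
                if hash_value(value) != stored_hash:
                    raise RuntimeError("cached return value was mutated")
                return value
            value = fn(*args)
            store[args] = (value, hash_value(value))
            return value

        return wrapper

    return decorator


class FigureLike:
    """Hypothetical stand-in for a matplotlib Figure."""

    def __init__(self) -> None:
        # A lambda attribute makes pickle.dumps fail, so default_hash cannot handle it.
        self.callback = lambda: None


calls = []


@toy_cache(hash_funcs={FigureLike: lambda _: None})
def make_plot(n: int) -> FigureLike:
    calls.append(n)  # records real (uncached) executions
    return FigureLike()


first = make_plot(1)
second = make_plot(1)  # served from the cache, not recomputed
```

The trade-off is the same as with st.cache: because every FigureLike hashes to None, the mutation check can no longer notice if the cached object is modified.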
Check this Streamlit GitHub issue.
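Another option, sketched here as an assumption rather than something the linked issue recommends, is to render the summary plot once to PNG bytes and cache those bytes instead of the Figure. Bytes are cheap for st.cache to hash, and closing the figure frees its memory. The helper name figure_to_png is hypothetical, and the block below depends only on matplotlib; a toy line plot stands in for the SHAP summary plot.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, suitable for a server
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def figure_to_png(fig: Figure, dpi: int = 100) -> bytes:
    """Render a matplotlib figure to PNG bytes, then close it to free memory."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    plt.close(fig)
    return buf.getvalue()


# Toy figure standing in for the SHAP summary plot.
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
png = figure_to_png(fig)
```

In the app, the cached function would build the figure as in the question (passing show=False to shap.summary_plot so SHAP does not call plt.show() itself) and return figure_to_png(fig); the page would then display it with st.image(io.BytesIO(png)). The cost is that the cached result is a static image rather than a live Figure.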