
How to cache a plot in Streamlit?


I have built a Streamlit dashboard where you can select a client_ID and see SHAP plots (waterfall and force plots) that explain the model's credit-default prediction for that client.

I also want to display a SHAP summary plot for the whole training dataset. That plot does not change each time a new prediction is made, but it takes a long time to draw, so I want to cache it. I assume st.cache is the right tool, but I have not been able to make it work.

Below is the code I tried unsuccessfully in main.py. I first define the function whose output (fig) I want to cache, then pass that output to st.pyplot. It works without the st.cache decorator, but as soon as I add the decorator and rerun the app, summary_plot_all runs indefinitely.

IN:

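import matplotlib.pyplot as plt
import shap
import streamlit as st

# shapvs (SHAP values) and prep_train (training DataFrame) are defined earlier in main.py
@st.cache
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig

st.pyplot(summary_plot_all())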

OUT (displayed in the Streamlit app, which never finishes):

Running summary_plot_all().

Does anyone know what is wrong, or a better way to cache a plot in Streamlit?

Package versions:
streamlit==0.84.1
matplotlib==3.4.2
shap==0.39.0

Solution

  • Try telling st.cache not to hash the returned matplotlib Figure, via hash_funcs:

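    import matplotlib.figure
    import matplotlib.pyplot as plt
    import shap
    import streamlit as st

    # Map Figure to a constant hash so st.cache skips hashing the (large) Figure object
    @st.cache(hash_funcs={matplotlib.figure.Figure: lambda _: None})
    def summary_plot_all():
        fig, axes = plt.subplots(nrows=1, ncols=1)
        shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                          prep_train.columns, max_display=50)
        return fig

    st.pyplot(summary_plot_all())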
    

    See this Streamlit GitHub issue for background.
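An alternative sketch, not taken from the issue above: render the plot once to PNG bytes and cache the bytes instead of the Figure, since a bytes object is cheap to hash. To keep the example runnable outside Streamlit, it uses `functools.lru_cache` as a stand-in for `st.cache`, and a placeholder bar chart stands in for the SHAP summary plot; `render_summary_png` and its data are hypothetical. In the real app you would draw with `shap.summary_plot(..., show=False)` inside the function and display the result with `st.image(render_summary_png())`.

```python
import functools
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

CALLS = {"count": 0}  # counts how often the expensive render actually runs


@functools.lru_cache(maxsize=1)  # stand-in for st.cache in this sketch
def render_summary_png() -> bytes:
    """Hypothetical stand-in for the slow summary plot: draw once, return PNG bytes."""
    CALLS["count"] += 1
    fig, ax = plt.subplots(nrows=1, ncols=1)
    # Placeholder data; the real app would call shap.summary_plot(..., show=False) here
    ax.barh(["feature_a", "feature_b", "feature_c"], [0.3, 0.2, 0.1])
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)  # free the Figure; only the bytes are kept in the cache
    return buf.getvalue()


png_first = render_summary_png()   # renders the plot
png_second = render_summary_png()  # served from the cache, no re-render
```

The design choice is to cache the final image rather than the live Figure: the cached value is small and immutable, so nothing expensive has to be hashed or checked for mutation on each rerun.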