Tags: python, matplotlib, seaborn, pivot-table, heatmap

How to get single colorbar with shared x- and y-axis for seaborn heatmaps in subplot?


I want to plot multiple confusion matrices in a single figure, with a single colorbar and shared x- and y-axes. Here is the code I have tried so far:

# Calculate the confusion matrices (rows: actual class, columns: predicted class).
actual_class = df_binary["Observed"]
model_names = ["Model1", "Model2", "Model3", "Model4", "Model5", "Model6"]

# One sorted label set shared by every matrix, so all matrices have the same
# shape and row/column order (needed for shared x/y axes), even if some model
# never predicts one of the classes.
class_labels = sorted(
    set(actual_class).union(*(set(df_binary[name]) for name in model_names))
)


def _confusion_matrix(predicted, actual):
    """Return counts of ``actual`` (rows) vs. ``predicted`` (columns) labels.

    Label combinations that never occur are filled with 0, so the result is
    always a square integer table indexed by ``class_labels``, with the index
    named "Actual" and the columns named "Predicted".
    """
    counts = pd.crosstab(actual, predicted, rownames=["Actual"], colnames=["Predicted"])
    return counts.reindex(index=class_labels, columns=class_labels, fill_value=0)


CF = {name: _confusion_matrix(df_binary[name], actual_class) for name in model_names}
# Individual names used by the plotting code (including Model3, which the
# plots reference).
CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6 = (CF[name] for name in model_names)

I then plotted these matrices using the following code:

fig = plt.figure(figsize=(6, 3), dpi=300)
fig.subplots_adjust(hspace=0.8, wspace=0.6)

# Draw one annotated heatmap per model on a 2x3 grid (row-major order).
# Each call creates its own colorbar and its own color scale.
matrices = [CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6]
for position, matrix in enumerate(matrices, start=1):
    ax = fig.add_subplot(2, 3, position)
    sns.heatmap(matrix, ax=ax, cmap='Blues', annot=True, fmt='d')

plt.show()

[Image: current output] My expected output is something like the following: [Image: expected output] How can I get only one colorbar, with shared x- and y-axes?

Data

Model1,Model2,Model3,Model4,Model5,Model6,Observed
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,Yes,No,Yes,Yes
No,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,Yes,No,No,No,Yes,No
No,Yes,No,No,No,Yes,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,Yes,No,Yes,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No

Solution

  • Following the logic from this answer, you can loop through the subplots (created with sharex=True and sharey=True, which removes the tick labels from plots that are not on the edges), plot the data, and remove the ylabel if the plot is not in the first column and the xlabel if it is not in the last row. To ensure a consistent color scale, a global vmin and vmax are computed before the loop.

    nrows = 2
    ncols = 3
    fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True)
    # Dedicated axes for the single shared colorbar, just right of the grid
    # ([left, bottom, width, height] in figure coordinates).
    cbar_ax = fig.add_axes([0.91, 0.3, 0.03, 0.4])

    data = [CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6]

    # Global min and max, so every heatmap uses the same color scale;
    # this is what makes a single colorbar valid for all plots.
    vmin = min(d.min().min() for d in data)
    vmax = max(d.max().max() for d in data)

    for i, (ax, d) in enumerate(zip(axes.flat, data)):
        # fmt="d" prints the integer counts as-is; seaborn's default ".2g"
        # would show e.g. 120 as "1.2e+02".
        sns.heatmap(d, ax=ax, annot=True, fmt="d",
                    vmin=vmin, vmax=vmax, cmap="Blues",
                    # Only the first heatmap draws the colorbar, into cbar_ax.
                    cbar=(i == 0), cbar_ax=cbar_ax if i == 0 else None)
        row, col = divmod(i, ncols)
        # Keep the y label only in the first column ...
        if col != 0:
            ax.set_ylabel("")
        # ... and the x label only in the last row.
        if row != nrows - 1:
            ax.set_xlabel("")
    # plt.show() runs the GUI event loop; Figure.show() does not, so in a plain
    # script the window could appear only briefly or not at all.
    plt.show()
    

    Result:

    For the axes labels, you can also use figure-level labels (`fig.supxlabel` / `fig.supylabel`, available in Matplotlib 3.4+) and remove the individual axes labels.

    for i, (ax, d) in enumerate(zip(axes.flat, data)):
        sns.heatmap(d, ax=ax, annot=True, fmt="d",
                    vmin=vmin, vmax=vmax, cmap="Blues",
                    # Keep the single shared colorbar: only the first heatmap
                    # draws it, into the dedicated cbar_ax. With cbar=False
                    # there would be no colorbar and cbar_ax would stay empty.
                    cbar=(i == 0), cbar_ax=cbar_ax if i == 0 else None)
        # Per-axes labels are replaced by one figure-level label per direction.
        ax.set_xlabel("")
        ax.set_ylabel("")
    # Figure-level axis labels (Figure.supxlabel/supylabel need Matplotlib >= 3.4).
    fig.supxlabel("Predicted")
    fig.supylabel("Actual")
    

    Result:


    Edit: To put a title above each plot, simply add `ax.set_title` to the loop.

    for i, (ax, d) in enumerate(zip(axes.flat, data)):
        # fmt="d" shows the integer counts without scientific notation.
        sns.heatmap(d, ax=ax, annot=True, fmt="d",
                    vmin=vmin, vmax=vmax, cmap="Blues",
                    # Only the first heatmap draws the shared colorbar into cbar_ax.
                    cbar=(i == 0), cbar_ax=cbar_ax if i == 0 else None)
        ax.set_xlabel("")
        ax.set_ylabel("")
        # Subplots are filled in row-major order, so i + 1 is the model number.
        ax.set_title(f"Model {i+1}")
    

    Result:


    Edit: To set the titles automatically, use the model column names of the DataFrame.

    
    # Title each subplot with its model's column name. "Observed" holds the
    # ground truth rather than a model, so exclude it explicitly instead of
    # relying on zip() stopping before the last column.
    # NOTE: assumes `data` lists the matrices in the same order as the model columns.
    model_titles = df_binary.columns.drop("Observed")
    for i, (ax, d, title) in enumerate(zip(axes.flat, data, model_titles)):
        sns.heatmap(d, ax=ax, annot=True, fmt="d",
                    vmin=vmin, vmax=vmax, cmap="Blues",
                    cbar=(i == 0), cbar_ax=cbar_ax if i == 0 else None)
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_title(title)