python matplotlib subplot multiple-axes

Populating matplotlib subplots with figures generated from function


I have found and adapted the following code snippets for generating diagnostic plots for linear regression. This is currently done using the following functions:

def residual_plot(some_values):
    """Draw a seaborn residual plot, label it, and display it immediately.

    As written in this snippet, ``some_values`` is not used in the body
    (``sns.residplot()`` is called without arguments), and the function has
    no ``return`` statement, so callers receive ``None``.
    """
    # The figure handle created here is overwritten on the next line.
    plot_lm_1 = plt.figure(1)    
    plot_lm_1 = sns.residplot()
    # NOTE(review): the code indexes `.axes[0]` on whatever sns.residplot()
    # returned - confirm that object actually exposes a list of axes.
    plot_lm_1.axes[0].set_title('title')
    plot_lm_1.axes[0].set_xlabel('label')
    plot_lm_1.axes[0].set_ylabel('label')
    # The plot is shown from inside the function rather than handed back.
    plt.show()


def qq_plot(residuals):
    """Build a Q-Q plot from ``residuals``, label it, and display it immediately.

    The function has no ``return`` statement, so callers receive ``None``.
    """
    # NOTE(review): ProbPlot is neither defined nor imported in this snippet -
    # presumably it comes from statsmodels; confirm against the full script.
    QQ = ProbPlot(residuals)
    plot_lm_2 = QQ.qqplot()    
    # The result of qqplot() is treated as an object with an `.axes` list;
    # only its first axes is labelled.
    plot_lm_2.axes[0].set_title('title')
    plot_lm_2.axes[0].set_xlabel('label')
    plot_lm_2.axes[0].set_ylabel('label')
    # The plot is shown from inside the function rather than handed back.
    plt.show()

which are called with something like:

# Neither residual_plot nor qq_plot (as defined above) has a return statement,
# so plot1..plot4 are all None; each call displays its own plot via plt.show().
plot1 = residual_plot(value_set1)
plot2 = qq_plot(value_set1)
plot3 = residual_plot(value_set2)
plot4 = qq_plot(value_set2)

How can I create subplots so that these 4 plots are displayed in a 2x2 grid?
I have tried using:

# NOTE: `axes[0,0].plot1` is an attribute lookup on the Axes object, not a call
# that draws the earlier `plot1` result into it; this is the lookup that the
# AttributeError quoted below complains about.
fig, axes = plt.subplots(2,2)
    axes[0,0].plot1
    axes[0,1].plot2
    axes[1,0].plot3
    axes[1,1].plot4
    plt.show()

but receive the error:

AttributeError: 'AxesSubplot' object has no attribute 'plot1'

Should I set up the axes attributes from within the functions, or somewhere else?


Solution

  • You should create a single figure with four subplot axes and pass each of those axes as an input to your custom plot functions, as in the following example:

    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from scipy.stats import probplot
    
    
    def residual_plot(x, y, axes=None):
        """Draw a seaborn residual plot of ``y`` against ``x`` on the given axes.

        Parameters
        ----------
        x, y
            Predictor and response data, forwarded to ``sns.residplot``.
        axes : optional
            Matplotlib axes to draw into. When ``None``, a new figure with a
            single subplot is created and used instead.

        Returns
        -------
        The value returned by ``sns.residplot``.
        """
        if axes is None:
            # No target supplied: fall back to a fresh single-axes figure.
            axes = plt.figure().add_subplot(1, 1, 1)
        # Pass the data by keyword. Recent seaborn releases (0.12+) make x/y
        # keyword-only and treat the first positional argument as `data`, so the
        # positional form `residplot(x, y, ...)` fails there; the keyword form
        # works on both old and new versions.
        p = sns.residplot(x=x, y=y, ax=axes)
        axes.set_xlabel("Data")
        axes.set_ylabel("Residual")
        axes.set_title("Residuals")
        return p
    
    
    def qq_plot(x, axes = None):
        """Draw a probability (Q-Q) plot of ``x`` with SciPy's ``probplot``.

        Parameters
        ----------
        x
            Sample data, forwarded to ``probplot``.
        axes : optional
            Matplotlib axes to draw into. When ``None``, a new figure with a
            single subplot is created and used instead.

        Returns
        -------
        The value returned by ``probplot``.
        """
        # Use the caller's axes when given; otherwise create a one-subplot figure.
        target = axes if axes is not None else plt.figure().add_subplot(1, 1, 1)
        result = probplot(x, plot = target)
        # Fix the horizontal range so several Q-Q panels share the same x-scale.
        target.set_xlim(-3, 3)
        return result
    
    
    if __name__ == "__main__":
        # Generate data: a linear trend and two copies with independent
        # standard-normal noise added.
        x = np.arange(100)
        y = 0.5 * x
        y1 = y + np.random.randn(100)
        y2 = y + np.random.randn(100)

        # Initialize the figure together with its 2x2 grid of axes.
        fig, grid = plt.subplots(2, 2, figsize=(8, 8), facecolor="white")
        (ax1, ax2), (ax3, ax4) = grid

        # Plot data: one row per noisy series (residuals left, Q-Q plot right).
        p1 = residual_plot(y, y1, ax1)
        p2 = qq_plot(y1, ax2)
        p3 = residual_plot(y, y2, ax3)
        p4 = qq_plot(y2, ax4)

        fig.tight_layout()
        # Use plt.show() instead of fig.show(): Figure.show() does not run a GUI
        # event loop, so in a plain script the window may vanish immediately or
        # never appear, whereas plt.show() blocks until the window is closed.
        plt.show()
    

    I do not know what your ProbPlot function is, so I used SciPy's probplot instead.