Tags: python, matplotlib, visualization, scatter-plot, subplot

How to ensure the abline / identity line is centered 1:1 in each subplot with individual axis limits in Matplotlib?


I have a 2x2 subplot grid, where each subplot contains a scatter plot with different data points. I'm trying to draw a common abline (slope=1, intercept=0) in each subplot to visualize the relationship between the data points. However, due to varying data ranges in each subplot, the abline does not appear centered 1:1 in all subplots.

I want to ensure that the abline is centered 1:1 in each subplot while maintaining individual axis limits for each plot based on the data points within that specific subplot. In other words, I want the abline to pass through the center of each subplot's data points without distorting the data.

Could someone please guide me on how to achieve the correct centering of the abline in each subplot while keeping individual axis limits based on the data points within that subplot?

Here is the code:

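import numpy as np
import matplotlib.pyplot as plt

# nan_mask, data_tmwm and ds_og_red are 3-D arrays indexed by timestep first (defined elsewhere)
timesteps = [185, 159, 53, 2]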

def abline(ax, slope, intercept):
    """Plot a line from slope and intercept"""
    x_vals = np.array(ax.get_xlim())
    y_vals = intercept + slope * x_vals
    ax.plot(x_vals, y_vals, 'r--')

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

for i, timestep in enumerate(timesteps):
    mask = np.where(nan_mask[timestep, :, :] == 0)
    data_tmwm_values = data_tmwm[timestep, :, :][mask]
    ds_plot_values = ds_og_red[timestep, :, :][mask]

    row = i // 2  # Integer division to get the row index
    col = i % 2  # Modulo operation to get the column index
    
    ax = axs[row, col]
    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]

    r_squared = r_value ** 2
    abline(ax, 1, 0)
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')

plt.tight_layout()
plt.show()

And this is the resulting image:

[image: the 2x2 grid of scatter plots produced by the code above]

I have already tried using the get_xlim() and get_ylim() functions to set the axis limits for each subplot, but it doesn't result in a proper centering of the abline.


Solution

  • It seems like you want an identity line, but you are attempting a linear fit instead. A linear fit might still be useful, since you compute the correlation coefficient and overlay R² on each plot.

    The example below shows how to add a linear fit as well as an identity (y=x) line.

    import matplotlib.pyplot as plt
    import numpy as np
    
    timesteps = [185, 159, 53, 2]
    
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    
    for timestep, ax in zip(timesteps, axs.flatten()):
        # Synthetic data standing in for the real arrays
        data_tmwm_values = np.random.randn(200) * 10 + timestep / 2
        ds_plot_values = np.random.randn(200) * 20 + timestep / 2
    
        ax.scatter(data_tmwm_values, ds_plot_values, s=20)
        ax.set_xlabel('TMWM')
        ax.set_ylabel('Original')
        ax.set_title(f'Scatter Plot (Timestep: {timestep})')
    
        correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
        r_value = correlation_matrix[0, 1]
        r_squared = r_value ** 2
        ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')
        
        # Fit a straight line to the data
        slope, intercept = np.polyfit(data_tmwm_values, ds_plot_values, deg=1)
        # Draw the fit across the current x-range, preserving the x and y ranges of the data
        x_low, x_high, y_low, y_high = ax.axis()  # current axis limits
        x_fit = np.array([x_low, x_high])
        ax.plot(x_fit, slope * x_fit + intercept, 'r--', label='linear fit of data')
    
        # Add the identity line, in case that is what you wanted
        lim_low = min(x_low, y_low)
        lim_high = max(x_high, y_high)
        ax.plot([lim_low, lim_high], [lim_low, lim_high], '-k', linewidth=2, label='y=x identity line')
    
        # Add a legend to one subplot to clarify what the lines represent
        if ax is axs[0, 1]:
            ax.legend(loc='upper right')
    
        # Optional: set identical x and y limits. This removes the padding added
        # by the lines and makes the identity line run corner to corner.
        ax.axis([lim_low, lim_high, lim_low, lim_high])
        
    plt.tight_layout()
    plt.show()
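
If only the identity line is needed, a more compact variant is possible with `Axes.axline`, which draws an infinite line through a point with a given slope (available in Matplotlib 3.3 and newer). The sketch below is one possible approach, not the only one: `add_identity_line` is a hypothetical helper name, and the data is synthetic. Each subplot keeps its own data-driven range, but the x and y limits within a subplot are made identical so that y = x appears as the diagonal.

```python
import numpy as np
import matplotlib.pyplot as plt


def add_identity_line(ax, **line_kwargs):
    """Draw y = x on `ax` and give both axes the same limits.

    Hypothetical helper for illustration, not part of Matplotlib.
    """
    # Read the limits that autoscaling chose for the data already plotted
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    low, high = min(x_low, y_low), max(x_high, y_high)

    # Infinite line through (low, low) with slope 1, i.e. y = x
    line = ax.axline((low, low), slope=1, **line_kwargs)

    # Identical limits on both axes, so y = x runs corner to corner
    ax.set_xlim(low, high)
    ax.set_ylim(low, high)
    # One data unit has the same length on both axes
    ax.set_aspect('equal', adjustable='box')
    return line


rng = np.random.default_rng(0)  # seeded so the figure is reproducible
timesteps = [185, 159, 53, 2]

fig, axs = plt.subplots(2, 2, figsize=(8, 8))
for timestep, ax in zip(timesteps, axs.flat):
    # Synthetic stand-in data (assumption: the real arrays are 1-D after masking)
    x = rng.normal(timestep / 2, 10, size=200)
    y = x + rng.normal(0, 5, size=200)
    ax.scatter(x, y, s=20)
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')
    add_identity_line(ax, color='k', linestyle='--', label='y = x')

fig.tight_layout()
```

Call `plt.show()` afterwards to display the figure. The helper has to run after the scatter call, because it reads the limits that Matplotlib derived from the data.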