I have a 2x2 subplot grid, where each subplot contains a scatter plot with different data points. I'm trying to draw a common abline (slope=1, intercept=0) in each subplot to visualize the relationship between the data points. However, due to varying data ranges in each subplot, the abline does not appear centered 1:1 in all subplots.
I want to ensure that the abline is centered 1:1 in each subplot while maintaining individual axis limits for each plot based on the data points within that specific subplot. In other words, I want the abline to pass through the center of each subplot's data points without distorting the data.
Could someone please guide me on how to achieve the correct centering of the abline in each subplot while keeping individual axis limits based on the data points within that subplot?
That's the code:
```python
import matplotlib.pyplot as plt
import numpy as np

# nan_mask, data_tmwm and ds_og_red are 3-D arrays (timestep, row, col) defined elsewhere
timesteps = [185, 159, 53, 2]

def abline(ax, slope, intercept):
    """Plot a line from slope and intercept"""
    x_vals = np.array(ax.get_xlim())
    y_vals = intercept + slope * x_vals
    ax.plot(x_vals, y_vals, 'r--')

fig, axs = plt.subplots(2, 2, figsize=(12, 8))
for i, timestep in enumerate(timesteps):
    mask = np.where(nan_mask[timestep, :, :] == 0)
    data_tmwm_values = data_tmwm[timestep, :, :][mask]
    ds_plot_values = ds_og_red[timestep, :, :][mask]
    row = i // 2  # Integer division to get the row index
    col = i % 2   # Modulo operation to get the column index
    ax = axs[row, col]
    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')
    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]
    r_squared = r_value ** 2
    abline(ax, 1, 0)
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')

plt.tight_layout()
plt.show()
```
And that's the resulting figure (image not included here).
I have already tried using the get_xlim() and get_ylim() functions to set the axis limits for each subplot, but it doesn't result in a proper centering of the abline.
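It's not entirely clear whether you want an identity line or a linear fit. Your `abline(ax, 1, 0)` call draws the identity line y = x, but wanting the line to "pass through the center of each subplot's data points" describes a fitted line. A linear fit may be useful to you anyway, since you compute correlation metrics and overlay R².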
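To make that link concrete: for an ordinary least-squares line with an intercept, R² equals the squared Pearson r that `np.corrcoef` returns. Measuring the same points against y = x generally scores lower, because least squares minimizes the residuals over every possible straight line, y = x included. A small numpy sketch on synthetic data (the data and variable names are made up for illustration):

```python
import numpy as np

# Synthetic data, purely illustrative
rng = np.random.default_rng(42)
x = rng.normal(50, 10, 200)
y = 0.8 * x + rng.normal(0, 5, 200)

# Squared Pearson correlation, as computed in the question
r_squared_corr = np.corrcoef(x, y)[0, 1] ** 2

# Coefficient of determination of the least-squares line
slope, intercept = np.polyfit(x, y, deg=1)
y_hat = slope * x + intercept
ss_res = np.sum((y - y_hat) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r_squared_fit = 1 - ss_res / ss_tot

# Same formula, but with y = x used as the "prediction"
ss_res_identity = np.sum((y - x) ** 2)
r_squared_identity = 1 - ss_res_identity / ss_tot

print(r_squared_corr, r_squared_fit, r_squared_identity)
```

So the R² shown in each subplot describes the fitted line, not the 1:1 line; the identity-line value can even be negative when the points sit far from y = x.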
The example below shows how to add a linear fit as well as an identity (y = x) line, using synthetic data in place of your arrays.
```python
import matplotlib.pyplot as plt
import numpy as np

timesteps = [185, 159, 53, 2]
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
for timestep, ax in zip(timesteps, axs.flatten()):
    # Synthetic data
    data_tmwm_values = np.random.randn(200) * 10 + timestep / 2
    ds_plot_values = np.random.randn(200) * 20 + timestep / 2

    ax.scatter(data_tmwm_values, ds_plot_values, s=20)
    ax.set_xlabel('TMWM')
    ax.set_ylabel('Original')
    ax.set_title(f'Scatter Plot (Timestep: {timestep})')

    correlation_matrix = np.corrcoef(data_tmwm_values, ds_plot_values)
    r_value = correlation_matrix[0, 1]
    r_squared = r_value ** 2
    ax.text(0.05, 0.95, f"R\u00b2 value: {r_squared:.3f}", transform=ax.transAxes, ha='left', va='top')

    # Fit a straight line
    slope, intercept = np.polyfit(data_tmwm_values, ds_plot_values, deg=1)

    # Add the line to the plot, preserving the x and y ranges of the data
    x_low, x_high, y_low, y_high = ax.axis()  # Get axis limits
    ax.plot([x_low, x_high], slope * np.array([x_low, x_high]) + intercept, 'r--', label='linear fit of data')

    # Add identity line, in case that is what you wanted
    lim_low = min(x_low, y_low)
    lim_high = max(x_high, y_high)
    ax.plot([lim_low, lim_high], [lim_low, lim_high], '-k', linewidth=2, label='y=x identity line')

    # Add a legend to one plot to clarify what the lines represent
    if ax is axs[0, 1]:
        ax.legend(loc='upper right')

    # Optional: clip limits to remove some padding
    ax.axis([lim_low, lim_high, lim_low, lim_high])

plt.tight_layout()
plt.show()
```
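If the goal is only the 1:1 line from the question, an alternative sketch gives each subplot identical x and y limits computed from its own data, draws y = x with `Axes.axline` (available in Matplotlib 3.3 and later), and sets an equal aspect ratio so the line runs corner to corner at 45 degrees. The helper name `add_identity_line` and its `pad_frac` padding parameter are illustrative assumptions, and the data are synthetic:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def add_identity_line(ax, x, y, pad_frac=0.05):
    """Give ax identical x/y limits covering x and y, then draw y = x.

    add_identity_line and pad_frac are hypothetical names; pad_frac is the
    padding as a fraction of the combined data span.
    """
    low = min(np.min(x), np.min(y))
    high = max(np.max(x), np.max(y))
    pad = (high - low) * pad_frac
    low, high = low - pad, high + pad

    # axline draws an infinite line through (0, 0) with slope 1,
    # so it covers whatever region ends up visible
    ax.axline((0, 0), slope=1, color="k", linestyle="--", label="y = x")

    # Explicit limits switch off autoscaling for both axes
    ax.set_xlim(low, high)
    ax.set_ylim(low, high)
    ax.set_aspect("equal")  # one x unit is drawn as long as one y unit
    return low, high


rng = np.random.default_rng(0)
timesteps = [185, 159, 53, 2]
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
for timestep, ax in zip(timesteps, axs.flat):
    # Synthetic stand-in data, one cloud per timestep
    x = rng.normal(timestep / 2, 10, 200)
    y = x + rng.normal(0, 5, 200)
    ax.scatter(x, y, s=20)
    add_identity_line(ax, x, y)
    ax.set_title(f"Scatter Plot (Timestep: {timestep})")
fig.tight_layout()
```

Because `axline` draws an infinite line, it does not depend on the limits at the moment it is called, unlike the `abline` helper in the question, which reads `get_xlim()` before autoscaling has settled the final limits.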