Tags: python, matplotlib, plot-annotations

How to annotate a regression line with the proper text rotation


I have the following snippet of code to draw a best-fit line through a collection of points on a graph and annotate it with the corresponding R² value:

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

x = 50 * np.random.rand(20) + 50
y = 200 * np.random.rand(20)
plt.plot(x, y, 'o')

# k, n = np.polyfit(x, y, 1)
k, n, r, _, _ = scipy.stats.linregress(x, y)
line = plt.axline((0, n), slope=k, color='blue')
xy = line.get_xydata()
plt.annotate(
    f'$R^2={r**2:.3f}$',
    (xy[0] + xy[-1]) // 2,
    xycoords='axes fraction',
    ha='center', va='center_baseline',
    rotation=k, rotation_mode='anchor',
)

plt.show()

I have tried various (x, y) pairs, different xycoords values, and other keyword parameters of annotate, but I haven't been able to get the annotation to appear where I want it. How do I get the text annotation to appear above the line with the proper rotation, located either at the midpoint of the line or at either end?


Solution

    1. Annotation coordinates

    We cannot compute the coordinates using xydata here, as axline() just returns dummy xydata (probably due to the way matplotlib internally plots infinite lines):

    print(line.get_xydata())
    # array([[0., 0.],
    #        [1., 1.]])
    

    Instead we can compute the text coordinates based on the xlim():

    xmin, xmax = plt.xlim()
    xtext = (xmin + xmax) / 2  # true division: // would floor the midpoint
    ytext = k*xtext + n
    

    Note that these are data coordinates, so they should be used with xycoords='data' instead of 'axes fraction'.
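The question also asks about placing the label at either end of the line. A minimal sketch of that idea, assuming the line is y = k*x + n and is only drawn within the current x-limits; the helper name `label_anchor` and its `where` argument are hypothetical, not part of matplotlib:

```python
def label_anchor(k, n, xlim, where="middle"):
    """Return (x, y, ha) for a label on the line y = k*x + n.

    Hypothetical helper: `xlim` is the (xmin, xmax) pair from plt.xlim(),
    and `ha` is a horizontal alignment that keeps the text inside the axes.
    """
    xmin, xmax = xlim
    if where == "left":
        x, ha = xmin, "left"
    elif where == "right":
        x, ha = xmax, "right"
    elif where == "middle":
        x, ha = (xmin + xmax) / 2, "center"
    else:
        raise ValueError(f"unknown position: {where!r}")
    # The y value always lies on the line, in data coordinates
    return x, k * x + n, ha


print(label_anchor(2.0, 10.0, (50.0, 100.0)))  # (75.0, 160.0, 'center')
```

The returned x and y would be passed to annotate with xycoords='data', and the returned alignment as ha. Note that at either end the point may fall outside the y-limits if the line leaves the axes through the top or bottom edge, in which case the label would not be visible.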


    2. Annotation angle

    We cannot compute the angle purely from the line points, as the angle will also depend on the axis limits and figure dimensions (e.g., imagine the required rotation angle in a 6x4 figure vs 2x8 figure).
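To make that dependence concrete, here is a small sketch in plain Python. The function name `visual_angle` is hypothetical, and it assumes the axes fill the figure in the same proportion horizontally and vertically, which is only approximately true with default subplot margins:

```python
import math


def visual_angle(k, xlim, ylim, figsize):
    """Approximate on-screen angle (degrees) of a line with data slope k.

    Assumes the axes occupy the same fraction of the figure in both
    directions, so figure inches can stand in for axes inches.
    """
    width, height = figsize
    run = width / (xlim[1] - xlim[0])        # inches per data unit along x
    rise = k * height / (ylim[1] - ylim[0])  # inches climbed per data unit of x
    return math.degrees(math.atan2(rise, run))


xlim, ylim = (50, 100), (0, 100)
print(round(math.degrees(math.atan(2)), 2))             # 63.43, ignores the scales
print(round(visual_angle(2, xlim, ylim, (6, 4)), 2))    # 33.69 in a 6x4 figure
print(round(visual_angle(2, xlim, ylim, (2, 8)), 2))    # 75.96 in a 2x8 figure
```

The same slope of 2 therefore needs three different rotation values depending on what is taken into account, which is why the arctangent of the slope alone does not line the text up with the drawn line.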

    Instead we should normalize the calculation to both scales to get the proper visual rotation:

    import matplotlib.pyplot as plt
    import numpy as np
    import scipy.stats
    
    rs = np.random.RandomState(0)
    x = 50 * rs.rand(20) + 50
    y = 200 * rs.rand(20)
    plt.plot(x, y, 'o')
    
    # save ax and fig scales
    xmin, xmax = plt.xlim()
    ymin, ymax = plt.ylim()
    xfig, yfig = plt.gcf().get_size_inches()
    
    k, n, r, _, _ = scipy.stats.linregress(x, y)
    plt.axline((0, n), slope=k, color='blue')
    
    # restore x and y limits after axline
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)
    
    # find text coordinates at midpoint of regression line
    xtext = (xmin + xmax) / 2
    ytext = k*xtext + n
    
    # find run and rise of (xtext, ytext) vs (0, n)
    dx = xtext
    dy = ytext - n
    
    # normalize to ax and fig scales
    xnorm = dx * xfig / (xmax - xmin)
    ynorm = dy * yfig / (ymax - ymin)
    
    # find normalized annotation angle (arctan2 returns radians; convert to degrees)
    rotation = np.rad2deg(np.arctan2(ynorm, xnorm))
    
    plt.annotate(
        f'$R^2={r**2:.3f}$',
        (xtext, ytext), xycoords='data',
        ha='center', va='bottom',
        rotation=rotation, rotation_mode='anchor',
    )
    
    plt.show()
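As a possible alternative to normalizing by hand, reasonably recent Matplotlib versions accept a `transform_rotates_text=True` text property, which interprets `rotation` as an angle in data coordinates and converts it to screen space at draw time. The sketch below assumes that property is available in your version and uses `ax.text` rather than `annotate`, since the text's own transform is then the data transform:

```python
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

rs = np.random.RandomState(0)
x = 50 * rs.rand(20) + 50
y = 200 * rs.rand(20)

fig, ax = plt.subplots()
ax.plot(x, y, 'o')

# remember the limits chosen for the scatter points
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()

k, n, r, _, _ = scipy.stats.linregress(x, y)
ax.axline((0, n), slope=k, color='blue')

# restore the limits, as in the answer above
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# midpoint of the visible part of the line, in data coordinates
xtext = (xmin + xmax) / 2
ytext = k * xtext + n

# angle of the line in data coordinates (not yet the on-screen angle)
data_angle = np.degrees(np.arctan(k))

label = ax.text(
    xtext, ytext, f'$R^2={r**2:.3f}$',
    ha='center', va='bottom',
    rotation=data_angle, rotation_mode='anchor',
    transform_rotates_text=True,  # let Matplotlib map the angle to screen space
)

# plt.show()  # uncomment when running interactively
```

Because the conversion happens when the figure is drawn, the label should stay aligned with the line if the window is resized or the limits change, which the manual calculation does not handle.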