I am trying to customize a Seaborn PairGrid with the following:
I think it can be done by just getting the handles(?), but I am not sure how to do that. This answer is good for JointPlots, but what is the equivalent of `ax = g.ax_joint` for PairGrids?
I would also like to be able to add a 1:1 identity line, if possible without having to define a function as the answer here suggests.
```python
import pandas as pd
import numpy as np
import seaborn as sns
np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)
                   })
g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)
```
(Update: this now uses `axline` to draw a diagonal line that reaches the subplot borders, as suggested in the comments. This function was added in matplotlib 3.3.0. Note that for accuracy reasons (on log-log axes), `axline` still needs one point close to the minimum and another close to the maximum. Those two points also influence the axis limits.)
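A minimal standalone sketch of the two-point `axline` call on log-log axes, assuming matplotlib 3.3 or later. The endpoints `lo` and `hi` are placeholder values standing in for the data minimum and maximum. It also shows why the single-point `slope=` form is not an option here: matplotlib raises a `TypeError` when `slope` is combined with a non-linear scale.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xscale("log")
ax.set_yscale("log")

# Placeholder endpoints near the (assumed) data minimum and maximum.
lo, hi = 0.01, 1.0
ax.axline((lo, lo), (hi, hi), color="crimson", linestyle="--")

# axline adds both points to the data limits, so autoscaling covers [lo, hi].
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
print(xmin, xmax, ymin, ymax)

# The one-point-plus-slope form is refused on non-linear scales.
slope_rejected = False
try:
    ax.axline((lo, lo), slope=1)
except TypeError as err:
    slope_rejected = True
    print("slope rejected:", err)
```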
To access the axes in a 2D way, you can use `g.axes[row, col]`. To loop through the axes, you can use `for ax in g.axes.flat:`.
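As a sketch of the `g.ax_joint` equivalent: `g.axes` is a 2D NumPy array of matplotlib Axes, so you can pick one subplot directly, or loop over all of them and call Axes methods. That also gives you the identity line without writing a mapped function. This assumes the same random `df` as in the question; `np.ndenumerate` is used only to get the row/column index so the (empty) diagonal subplots can be skipped. `g.axes[row, col]` shows `g.x_vars[col]` on its x axis and `g.y_vars[row]` on its y axis.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(0)
df = pd.DataFrame({"x": np.random.rand(10),
                   "y": np.random.rand(10),
                   "z": np.random.rand(10)})

g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)

# One subplot: row 0, column 1 shows g.x_vars[1] ('y') against g.y_vars[0] ('x').
top_middle = g.axes[0, 1]
top_middle.set_facecolor("whitesmoke")  # any Axes method works on it
print(g.x_vars[1], g.y_vars[0])

# Loop over every subplot; skip the diagonal, which has nothing drawn on it here.
xy_min, xy_max = df.min().min(), df.max().max()
for (row, col), ax in np.ndenumerate(g.axes):
    if row == col:
        continue
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.axline((xy_min, xy_min), (xy_max, xy_max),
              color="crimson", linestyle="--", linewidth=2)
    ax.grid(which="major", color="navy", lw=1, ls=":")
```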
You can also use the `g.map_...(given_function)` functions. These call `given_function` for each of the axes, passing the data column used for x as the first argument and the one used for y as the second. Optional parameters can be given via `g.map_...(given_function, param1=..., ...)` and will be collected in the `kwargs` dict (Seaborn may add a few of its own, such as `color`). Each time `given_function` is called, the current ax is already set (so it is not an extra parameter). You can then use `plt.plot` to plot directly on that ax, or get a reference to it with `ax = plt.gca()`.
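A small sketch that makes this calling convention visible, assuming Seaborn 0.11 or later (where the columns arrive as pandas Series, so `.name` is available). `record_call` and the `calls` list are hypothetical names used only for illustration: the function logs which columns it received and which subplot was current, then forwards `kwargs` (here `marker`, plus whatever Seaborn adds) to `ax.plot`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(0)
df = pd.DataFrame({"x": np.random.rand(10),
                   "y": np.random.rand(10),
                   "z": np.random.rand(10)})

calls = []  # hypothetical log of (x column name, y column name, current Axes)

def record_call(xdata, ydata, **kwargs):
    # Seaborn has already made the target subplot current, so gca() returns it.
    ax = plt.gca()
    calls.append((xdata.name, ydata.name, ax))
    # kwargs holds marker='x' from map_offdiag plus extras Seaborn adds (e.g. color).
    ax.plot(xdata, ydata, linestyle="none", **kwargs)

g = sns.PairGrid(df)
g.map_offdiag(record_call, marker="x")

for xname, yname, _ in calls:
    print(xname, yname)
```

With three columns and no `hue`, `record_call` runs once per off-diagonal subplot, so six times in total.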
Here is some example code tackling your questions. Because the identity line is drawn through the same two points on every subplot, the x and y limits end up equal. Note that by default the limits are shared (subplots in the same column share x, subplots in the same row share y), with tick labels only on the left column and bottom row.
```python
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def update_plot(xdata, ydata, xy_min, xy_max, **kwargs):
    # Seaborn makes the target subplot current before each call,
    # so the plt.* calls below act on that subplot.
    # **kwargs absorbs extras Seaborn passes, such as color.
    plt.yscale('log')
    plt.xscale('log')
    # plt.plot([xy_min, xy_max], [xy_min, xy_max], color='crimson', linestyle='--', linewidth=2)
    # axline (matplotlib >= 3.3) extends the identity line to the subplot borders
    plt.axline([xy_min, xy_min], [xy_max, xy_max], color='crimson', linestyle='--', linewidth=2)
    plt.grid(which='major', color='navy', lw=1, ls=':')
    plt.grid(which='minor', color='navy', lw=0.2, ls=':')

np.random.seed(0)
df = pd.DataFrame({'x': np.random.rand(10),
                   'y': np.random.rand(10),
                   'z': np.random.rand(10)})
g = sns.PairGrid(df)
g.map_offdiag(sns.scatterplot)
# the same global minimum and maximum are passed to every subplot
g.map_offdiag(update_plot, xy_min=df.min().min(), xy_max=df.max().max())
plt.subplots_adjust(left=0.1)  # a bit more room at the left for the labels
plt.show()
```