I would like to visualize the mean and percentiles of each variable plotted on a kdeplot.
The code in "kdeplot showing mean and quartiles" does draw the mean and percentiles on the plot, but I would like to do the same for a plot with several variables, such as the one produced by the code below:
sns.kdeplot(data=penguins, x="flipper_length_mm", hue="species", multiple="stack");
In other words, is there a way to obtain the transformed flipper_length_mm data used to generate the plot for each of the 3 species?
To get the values that created each of the curves, you can extract the rows for that species (e.g. x = penguins[penguins['species'] == 'Adelie']). To get the species names in the matching order, you can read them from the legend, in reverse order, because seaborn plots the last hue level first.
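Extracting the raw values per species, pairing them with the drawing order, and computing the summary numbers the question asks about can be sketched with pandas alone. This is a minimal sketch under stated assumptions: the small DataFrame is hypothetical toy data standing in for penguins, and hue_order is an assumed explicit list (kdeplot does accept a hue_order= argument, which would make the order known without reading the legend).

```python
import numpy as np
import pandas as pd

# Hypothetical toy data standing in for sns.load_dataset('penguins');
# the NaN mimics the missing flipper lengths in the real dataset.
df = pd.DataFrame({
    "species": ["Adelie", "Adelie", "Adelie", "Gentoo", "Gentoo", "Gentoo"],
    "flipper_length_mm": [180.0, 190.0, np.nan, 210.0, 220.0, 230.0],
})

# Assumed explicit order; passing the same list as hue_order= to sns.kdeplot
# would fix the plotting order in advance.
hue_order = ["Adelie", "Gentoo"]

# Boolean-mask extraction: one NumPy array of raw values per species.
values_by_species = {
    name: df.loc[df["species"] == name, "flipper_length_mm"].to_numpy()
    for name in hue_order
}

# The stacked curves are drawn last level first, so this is the order
# in which they would appear in ax.lines.
line_order = hue_order[::-1]

# Per-species mean and quartiles; pandas skips NaN in mean() and quantile().
grouped = df.groupby("species")["flipper_length_mm"]
stats = grouped.quantile([0.25, 0.5, 0.75]).unstack()
stats["mean"] = grouped.mean()
print(stats)
```

Note that these statistics come from the raw measurements, not from the smoothed curve; the curve is only needed to decide how tall each marker should be.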
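As the values can contain NaNs, np.nanmean() and np.nanstd() compute the mean and standard deviation while ignoring them. To account for the stacked kde curves, you can store the previous curve and fill only between the previous and the current curve; this works point by point because seaborn evaluates all stacked curves on a shared grid.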
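Why filling between the previous and the current curve isolates one species becomes clearer once stacking is written out: each plotted curve is the running sum of the group densities on one shared grid. The NumPy-only sketch below assumes a hand-written Gaussian KDE with a fixed bandwidth and toy arrays; it is an illustration, not seaborn's internals. The weighting by each group's share of observations mimics kdeplot's default common_norm=True.

```python
import numpy as np

def simple_gaussian_kde(values, grid, bw):
    """Hypothetical helper: Gaussian KDE with a fixed bandwidth (not seaborn's code)."""
    values = values[~np.isnan(values)]  # drop NaNs before smoothing
    z = (grid[:, None] - values[None, :]) / bw
    return np.exp(-0.5 * z**2).sum(axis=1) / (len(values) * bw * np.sqrt(2 * np.pi))

# Toy data for two groups (assumed values), with one missing measurement.
groups = {
    "A": np.array([180.0, 185.0, 190.0, np.nan]),
    "B": np.array([210.0, 215.0, 220.0, 225.0]),
}
grid = np.linspace(160, 240, 401)  # one grid shared by every curve
n_total = sum(np.count_nonzero(~np.isnan(v)) for v in groups.values())

prev_ys = np.zeros_like(grid)  # bottom edge of the first band
bands = {}
for name, vals in groups.items():
    share = np.count_nonzero(~np.isnan(vals)) / n_total
    # weight by the group's share of observations, like common_norm=True
    ys = prev_ys + share * simple_gaussian_kde(vals, grid, bw=5.0)
    mean = np.nanmean(vals)
    # the mean marker spans only this group's band: previous curve -> current curve
    bands[name] = (mean, np.interp(mean, grid, prev_ys), np.interp(mean, grid, ys))
    prev_ys = ys

# the top curve integrates to about 1 because all groups share one normalization
area = prev_ys.sum() * (grid[1] - grid[0])
print(bands, area)
```

Because every band's lower edge is the previous band's upper edge, np.interp on both curves at the mean gives exactly the vertical span that belongs to that group.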
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
penguins = sns.load_dataset('penguins')
ax = sns.kdeplot(data=penguins, x='flipper_length_mm', hue='species', multiple='stack', fill=False)
prev_ys = 0
for kdeline, legend_text in zip(ax.lines, ax.get_legend().get_texts()[::-1]):
    # raw flipper lengths of the species that belongs to this curve
    x = penguins.loc[penguins['species'] == legend_text.get_text(), 'flipper_length_mm'].to_numpy()
    mean = np.nanmean(x)  # NaN-safe mean
    std = np.nanstd(x)  # NaN-safe standard deviation
    # xs is the shared grid, ys the top edge of this species' stacked band
    xs = kdeline.get_xdata()
    ys = kdeline.get_ydata()
    # draw the mean only between the previous curve and the current one
    prev_height = 0 if np.isscalar(prev_ys) else np.interp(mean, xs, prev_ys)
    height = np.interp(mean, xs, ys)
    ax.vlines(mean, prev_height, height, color=kdeline.get_color(), ls=':')
    # fill this species' band (between the previous and the current curve)
    ax.fill_between(xs, prev_ys, ys, facecolor=kdeline.get_color(), alpha=0.2)
    # filter the region where x is within one standard deviation of the mean
    sd_filter = (xs >= mean - std) & (xs <= mean + std)
    # fill this region a second time so it shows a darker color
    ax.fill_between(xs[sd_filter], 0 if np.isscalar(prev_ys) else prev_ys[sd_filter], ys[sd_filter],
                    facecolor=kdeline.get_color(), alpha=0.2)
    prev_ys = ys
plt.show()
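The question also asks about percentiles. The same band-clipping idea extends to quartiles; here is a minimal sketch, assuming a hypothetical helper quantile_segments (not part of seaborn or matplotlib) that uses np.nanpercentile and returns the x positions plus the lower and upper band heights at each requested quantile.

```python
import numpy as np

def quantile_segments(xs, prev_ys, ys, values, qs=(0.25, 0.5, 0.75)):
    """Hypothetical helper: x position and band limits for each quantile of values.

    prev_ys may be the scalar 0 (bottom curve) or an array on the same grid as ys.
    """
    xq = np.nanpercentile(values, np.asarray(qs) * 100)  # NaN-aware percentiles
    prev = np.broadcast_to(prev_ys, np.shape(xs))  # turn the scalar 0 into a flat curve
    return xq, np.interp(xq, xs, prev), np.interp(xq, xs, ys)

# Toy check on a flat band of height 2 over the grid 0..10.
qx, lo, hi = quantile_segments(np.linspace(0, 10, 11), 0, np.full(11, 2.0),
                               np.array([1.0, 2.0, 3.0, 4.0, 5.0, np.nan]))
print(qx, lo, hi)
```

Inside the loop above, after xs and ys are read and before prev_ys = ys, this could be used as qx, lo, hi = quantile_segments(xs, prev_ys, ys, x) followed by ax.vlines(qx, lo, hi, color=kdeline.get_color(), ls='--').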