I am trying to plot split violin plots with Seaborn, i.e. a pair of KDE plots set back to back, typically to compare two distributions.
My use case is very similar to the docs, except that I would like to superimpose custom box plots on top (as in this tutorial). However, I am having a strange alignment issue between the violin plots and the x-axis, and I don't understand what I am doing differently from the docs...
Here's code for an MRE with only split violins:
```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

sns.set_theme()

data1 = np.random.normal(0, 1, 1000)
data2 = np.random.normal(1, 2, 1000)
data = pd.concat(
    [
        pd.DataFrame({"column": "1", "data1": data1, "data2": data2}),
        pd.DataFrame({"column": "2", "data3": data2, "data4": data1}),
    ],
    axis="rows",
)


def mkplot():
    fig, violin_ax = plt.subplots()
    sns.violinplot(
        data=data.melt(id_vars="column"),
        y="value",
        split=True,
        hue="variable",
        x="column",
        ax=violin_ax,
        palette="Paired",
        bw_method="silverman",
        inner=None,
    )
    plt.show()


mkplot()
```
This produces split violins whose middle is not aligned with the x-axis tick label:
[image: mis-aligned violins]
(This is also true when "column" is of numeric type rather than str.)
Box plots seem to be misaligned too, but not by the same amount; you can reproduce this with the function below and the same data:
```python
def mkplot2():
    fig, violin_ax = plt.subplots()
    sns.violinplot(
        data=data.melt(id_vars="column"),
        y="value",
        split=True,
        hue="variable",
        x="column",
        ax=violin_ax,
        palette="Paired",
        bw_method="silverman",
        inner=None,
    )
    sns.boxplot(
        data=data.melt(id_vars="column"),
        y="value",
        hue="variable",
        x="column",
        ax=violin_ax,
        palette="Paired",
        width=0.3,
        flierprops={"marker": "o", "markersize": 3},
        legend=False,
        dodge=True,
    )
    plt.show()


mkplot2()
```
How can I solve this?
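The issue is due to the NaNs that you have after melting. Each "column" ends up with rows for all four variables (data1 to data4), two of which are entirely NaN. The hue variable therefore has 4 levels, and seaborn reserves room for all of them at every x position, so the visible violins are shifted to account for the two empty groups.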
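A quick way to see this is to inspect the melted frame directly: the sketch below rebuilds the question's data (seeded only for reproducibility) and computes, for every ("column", "variable") pair, the share of NaN values. It then uses a deliberately simplified model, an assumption rather than seaborn's actual source, in which the element `width` is split into equal slots, one per hue level. Under that model each slot's offset from the tick is proportional to `width`, which is consistent with the boxes (`width=0.3`) drifting less than the violins (default `width=0.8`). `slot_offsets` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

# Rebuild the question's data; the seed is only there to make the example reproducible
rng = np.random.default_rng(0)
data1 = rng.normal(0, 1, 1000)
data2 = rng.normal(1, 2, 1000)
data = pd.concat(
    [
        pd.DataFrame({"column": "1", "data1": data1, "data2": data2}),
        pd.DataFrame({"column": "2", "data3": data2, "data4": data1}),
    ],
    axis="rows",
)

long = data.melt(id_vars="column")

# Share of NaN values per (column, variable): 1.0 means the hue level is empty there
nan_share = (
    long["value"].isna()
    .groupby([long["column"], long["variable"]])
    .mean()
    .unstack()
)
print(nan_share)


def slot_offsets(width, n_levels):
    """Hypothetical, simplified dodge model: split `width` into equal slots, one per
    hue level, and return each slot's centre relative to the tick position."""
    each = width / n_levels
    return [each * i + each / 2 - width / 2 for i in range(n_levels)]


# With 4 hue levels, offsets scale linearly with width
violin_offsets = slot_offsets(0.8, 4)  # default violinplot width
box_offsets = slot_offsets(0.3, 4)     # width=0.3 used for the boxes in mkplot2
print(violin_offsets)
print(box_offsets)
```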
You could plot the groups independently, so that each call only sees the two variables that actually exist in that column:
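```python
import matplotlib.pyplot as plt
import seaborn as sns

# `data` is the DataFrame built in the question; drop the NaN rows created by melt
data_flat = data.melt('column').dropna(subset=['value'])

violin_ax = plt.subplot()
pal = sns.color_palette('Paired')
# One violinplot call per column, so hue only has that column's two levels
for i, (name, g) in enumerate(data_flat.groupby('column')):
    sns.violinplot(
        data=g,
        y='value',
        split=True,
        hue='variable',
        x='column',
        ax=violin_ax,
        palette=pal[2*i:2*i+2],  # a distinct pair of Paired colours per column
        bw_method='silverman',
        inner=None,
    )
plt.show()
```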
Output:
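If you would rather keep a single `violinplot` call, one alternative (not part of the original answer, and only a sketch) is to map each column's two variables onto shared hue levels, e.g. "left" and "right", so hue has exactly two levels everywhere. The `SIDE` mapping and both helper functions below are hypothetical and written for this specific data; seaborn is imported inside `plot_sides` only so the data-preparation part runs on its own.

```python
import numpy as np
import pandas as pd

# Hypothetical mapping for this data: which half of the split violin each variable uses
SIDE = {"data1": "left", "data2": "right", "data3": "left", "data4": "right"}


def to_sides(long_df):
    """Drop the NaN rows left by melt and add a two-level 'side' column for hue."""
    out = long_df.dropna(subset=["value"]).copy()
    out["side"] = out["variable"].map(SIDE)
    return out


def plot_sides(long_df, ax=None):
    """Draw split violins with hue='side' (two levels), centred on each x tick."""
    import seaborn as sns  # local import: only needed when actually plotting

    return sns.violinplot(
        data=long_df,
        x="column",
        y="value",
        hue="side",
        split=True,
        inner=None,
        bw_method="silverman",
        ax=ax,
    )
```

Usage would be something like `plot_sides(to_sides(data.melt(id_vars="column")))` followed by `plt.show()`. The trade-off is that both columns share the same two colours and the legend shows "left"/"right" instead of data1 to data4, whereas the loop above keeps a distinct colour pair per column.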