I’m trying to add labels to a few values in my matplotlib/seaborn plot. Not all, just those above a certain value (below, using iris from sklearn, labels for values greater than 3.6 on the x-axis).
Here, from @Scinana last year, is a discussion of doing that when both axes are numeric. Although it includes an accepted answer, I’m having trouble adapting it to my situation, and the links provided in the accepted answer don't help either.
The code below works until the last step (the labeling), which throws: `TypeError: 'FacetGrid' object is not callable`
Additionally, the outliers need to be annotated with values from `dfiris['sepal length (cm)']`, not just `'outliers'`.
# Reproduction script from the question: build a DataFrame from the sklearn
# iris data, add a fake categorical column, draw a swarm catplot, then try to
# label every point whose sepal width is greater than 3.6.
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.datasets import load_iris

# the loaded dataset exposes the .data and .feature_names attributes used below
iris = load_iris()
dfiris = pd.DataFrame(data=iris.data, columns=iris.feature_names)

# adding a fake categorical variable
dfiris['name'] = np.where(dfiris['sepal width (cm)'] < 3, 'Amy', 'Bruce')
# adding to that fake categorical variable
dfiris['name'] = np.where(
    (dfiris.name != 'Amy') & (dfiris['petal length (cm)'] >= 3.4),
    'Charles',
    dfiris.name,
)

a_viz = sns.catplot(x='sepal width (cm)', y='name', kind='swarm', data=dfiris)
a_viz.fig.set_size_inches(5, 6)
a_viz.fig.subplots_adjust(top=0.81, right=0.86)

# NOTE: this labeling loop is the step the question reports as failing. It is
# left unchanged on purpose so the script still reproduces the problem:
# `a_viz` is the FacetGrid returned by catplot, not a matplotlib Axes, so the
# text call has to be made on one of the grid's Axes instead.
for x, y in zip(dfiris['sepal width (cm)'], dfiris['name']):
    if x > 3.6:
        a_viz.text(x, y, 'outlier', horizontalalignment='left', size='medium', color='black')
The following duplicates didn't completely address the issue with adding annotations from a different column, nor how to prevent the annotations from overlapping.
With a `swarmplot`, there's no way to distinguish the tick location on the independent axis for each observation, which means the text annotations for each value on the x-axis will overlap.
Use `pandas.DataFrame.groupby` to create the strings to be passed to `s=`.
# Annotate the swarm-plot outliers (sepal_width > 3.6) with their sepal_length
# values. Observations that share the same x value and category are merged
# into one multi-line label so the annotations do not overlap.
import seaborn as sns

# load sample data that has text labels
df = sns.load_dataset('iris')

# plot the DataFrame
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)

# there is only one axes for this plot; provide an alias for ease of use
ax = g.axes[0, 0]

# get the ytick locations for each name
ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}

# add the ytick locations for each observation
df['ytick_loc'] = df.species.map(ytick_loc)

# filter the dataframe to only contain the outliers
outliers = df[df.sepal_width.gt(3.6)].copy()

# convert the column to strings for annotations
outliers['sepal_length'] = outliers['sepal_length'].astype(str)

# combine all the sepal_length values as a single string for each species and width
labels = outliers.groupby(['sepal_width', 'ytick_loc']).agg({'sepal_length': '\n'.join}).reset_index()

# iterate through each axes of the FacetGrid with `for ax in g.axes.flat:` or specify the exact axes to use
# after reset_index() the columns are (sepal_width, ytick_loc, sepal_length),
# so each plain tuple unpacks directly into the x, y and label text
for x, y, s in labels.itertuples(index=False, name=None):
    # nudge the label slightly to the right of the marker it describes
    ax.text(x + 0.01, y, s=s, horizontalalignment='left', size='medium', color='black', verticalalignment='center', linespacing=1)
df
sepal_length sepal_width petal_length petal_width species ytick_loc
0 5.1 3.5 1.4 0.2 setosa 0
1 4.9 3.0 1.4 0.2 setosa 0
2 4.7 3.2 1.3 0.2 setosa 0
3 4.6 3.1 1.5 0.2 setosa 0
4 5.0 3.6 1.4 0.2 setosa 0
outliers
sepal_length sepal_width petal_length petal_width species ytick_loc
5 5.4 3.9 1.7 0.4 setosa 0
10 5.4 3.7 1.5 0.2 setosa 0
14 5.8 4.0 1.2 0.2 setosa 0
15 5.7 4.4 1.5 0.4 setosa 0
16 5.4 3.9 1.3 0.4 setosa 0
18 5.7 3.8 1.7 0.3 setosa 0
19 5.1 3.8 1.5 0.3 setosa 0
21 5.1 3.7 1.5 0.4 setosa 0
32 5.2 4.1 1.5 0.1 setosa 0
33 5.5 4.2 1.4 0.2 setosa 0
44 5.1 3.8 1.9 0.4 setosa 0
46 5.1 3.8 1.6 0.2 setosa 0
48 5.3 3.7 1.5 0.2 setosa 0
117 7.7 3.8 6.7 2.2 virginica 2
131 7.9 3.8 6.4 2.0 virginica 2
labels
sepal_width ytick_loc sepal_length
0 3.7 0 5.4\n5.1\n5.3
1 3.8 0 5.7\n5.1\n5.1\n5.1
2 3.8 2 7.7\n7.9
3 3.9 0 5.4\n5.4
4 4.0 0 5.8
5 4.1 0 5.2
6 4.2 0 5.5
7 4.4 0 5.7
# Simpler variant: write one text annotation per outlier (sepal_width > 3.6),
# labelled with that observation's sepal_length. Every label in a category is
# placed on that category's tick location, so labels for equal x values
# are drawn on top of each other.
import seaborn as sns

# load sample data that has text labels
df = sns.load_dataset('iris')

# plot the DataFrame
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)

# there is only one axes for this plot; provide an alias for ease of use
ax = g.axes[0, 0]

# get the ytick locations for each name
ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}

# plot the text annotations
for x, y, s in zip(df.sepal_width, df.species.map(ytick_loc), df.sepal_length):
    if x > 3.6:
        # the label is a number here; pass it to ax.text as an explicit string
        ax.text(x, y, str(s), horizontalalignment='left', size='medium', color='k')