python pandas matplotlib seaborn swarmplot

How to annotate swarmplot points on a categorical axis with labels from a different column


I’m trying to add labels to some of the points in my matplotlib/seaborn plot. I don't want labels on all of them, only on those above a certain value. In the example below, which uses the iris dataset from sklearn, I want labels on the points whose x value (sepal width) is greater than 3.6.

A question from @Scinana last year discusses doing this when both axes are numeric. It has an accepted answer, but I’m having trouble adapting it to my situation, and the links in the accepted answer don't help either.

The code below works until the last step (the labeling), which throws: TypeError: 'FacetGrid' object is not callable

Additionally, the outliers need to be annotated with their values from dfiris['sepal length (cm)'], not with the literal string 'outlier'.

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris

dfiris = load_iris()
dfiris = pd.DataFrame(data=dfiris.data, columns=dfiris.feature_names)
dfiris['name'] = np.where(dfiris['sepal width (cm)'] < 3, 'Amy', 'Bruce')  # adding a fake categorical variable 
dfiris['name'] = np.where((dfiris.name != 'Amy') & (dfiris['petal length (cm)'] >= 3.4), 'Charles', dfiris.name) # adding to that fake categorical variable 

a_viz = sns.catplot(x='sepal width (cm)', y='name', kind='swarm', data=dfiris)
a_viz.fig.set_size_inches(5, 6)
a_viz.fig.subplots_adjust(top=0.81, right=0.86)

for x, y in zip(dfiris['sepal width (cm)'], dfiris['name']):
    if x > 3.6:
        # this call fails: a_viz is a seaborn FacetGrid, not a matplotlib Axes
        a_viz.text(x, y, 'outlier', horizontalalignment='left', size='medium', color='black')

The questions suggested as duplicates didn't fully address how to add annotations from a different column, nor how to prevent the annotations from overlapping.


Solution

  • Non-overlapping Annotations

    import seaborn as sns
    
    # load sample data that has text labels
    df = sns.load_dataset('iris')
    
    # plot the DataFrame
    g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)
    
    # there is only one axes for this plot; provide an alias for ease of use
    ax = g.axes[0, 0]
    
    # get the ytick location for each species label
    ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}
    
    # add the ytick locations for each observation
    df['ytick_loc'] = df.species.map(ytick_loc)
    
    # filter the dataframe to only contain the outliers
    outliers = df[df.sepal_width.gt(3.6)].copy()
    
    # convert the column to strings for annotations
    outliers['sepal_length'] = outliers['sepal_length'].astype(str)
    
    # combine all the sepal_length values as a single string for each species and width
    labels = outliers.groupby(['sepal_width', 'ytick_loc']).agg({'sepal_length': '\n'.join}).reset_index()
    
    # iterate through each axes of the FacetGrid with `for ax in g.axes.flat:` or specify the exact axes to use
    for _, (x, y, s) in labels.iterrows():
        ax.text(x + 0.01, y, s=s, horizontalalignment='left', size='medium', color='black', verticalalignment='center', linespacing=1)
    

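Applied to the question's own column names, the fix is short: draw on the FacetGrid's single Axes (`a_viz.ax`) instead of on the grid, and look up each name's numeric y position from the tick labels. The sketch below assumes a small hand-made `dfiris` with made-up values (so it runs without sklearn or a download) and uses the `Agg` backend so it runs headless. It follows the same tick-label lookup as the answer, but it has not been checked against every seaborn version.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for dfiris: same column names, made-up values
dfiris = pd.DataFrame({
    'sepal width (cm)': [3.0, 3.8, 2.5, 3.9, 3.2, 3.7],
    'sepal length (cm)': [5.0, 5.7, 6.1, 5.4, 6.4, 7.7],
    'name': ['Amy', 'Bruce', 'Amy', 'Bruce', 'Charles', 'Charles'],
})

a_viz = sns.catplot(x='sepal width (cm)', y='name', kind='swarm', data=dfiris)

# A FacetGrid is not an Axes; with a single facet, a_viz.ax is the Axes to draw on
ax = a_viz.ax

# Draw once so the tick labels are populated, then map each name to its y position
ax.figure.canvas.draw()
ytick_loc = {t.get_text(): t.get_position()[1] for t in ax.get_yticklabels()}

# Label only the points above the threshold, using the value from another column
texts = []
for x, name, length in zip(dfiris['sepal width (cm)'], dfiris['name'], dfiris['sepal length (cm)']):
    if x > 3.6:
        texts.append(ax.text(x + 0.01, ytick_loc[name], str(length),
                             ha='left', va='center', size='medium', color='black'))
```

This version still places one label per point, so labels at the same location can overlap. The grouped, newline-joined labels in the answer avoid that.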


    DataFrame Views

    df.head()

       sepal_length  sepal_width  petal_length  petal_width species  ytick_loc
    0           5.1          3.5           1.4          0.2  setosa          0
    1           4.9          3.0           1.4          0.2  setosa          0
    2           4.7          3.2           1.3          0.2  setosa          0
    3           4.6          3.1           1.5          0.2  setosa          0
    4           5.0          3.6           1.4          0.2  setosa          0
    

    outliers

        sepal_length  sepal_width  petal_length  petal_width    species  ytick_loc
    5            5.4          3.9           1.7          0.4     setosa          0
    10           5.4          3.7           1.5          0.2     setosa          0
    14           5.8          4.0           1.2          0.2     setosa          0
    15           5.7          4.4           1.5          0.4     setosa          0
    16           5.4          3.9           1.3          0.4     setosa          0
    18           5.7          3.8           1.7          0.3     setosa          0
    19           5.1          3.8           1.5          0.3     setosa          0
    21           5.1          3.7           1.5          0.4     setosa          0
    32           5.2          4.1           1.5          0.1     setosa          0
    33           5.5          4.2           1.4          0.2     setosa          0
    44           5.1          3.8           1.9          0.4     setosa          0
    46           5.1          3.8           1.6          0.2     setosa          0
    48           5.3          3.7           1.5          0.2     setosa          0
    117          7.7          3.8           6.7          2.2  virginica          2
    131          7.9          3.8           6.4          2.0  virginica          2
    

    labels

       sepal_width  ytick_loc        sepal_length
    0          3.7          0       5.4\n5.1\n5.3
    1          3.8          0  5.7\n5.1\n5.1\n5.1
    2          3.8          2            7.7\n7.9
    3          3.9          0            5.4\n5.4
    4          4.0          0                 5.8
    5          4.1          0                 5.2
    6          4.2          0                 5.5
    7          4.4          0                 5.7
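The grouping step can be checked on its own with pandas. This sketch rebuilds only the columns it needs from the 15 rows in the `outliers` view above, then reruns the answer's `groupby` / `'\n'.join` step. The result should match the `labels` view: one newline-joined label per (sepal_width, ytick_loc) location.

```python
import pandas as pd

# The 15 outlier rows from the `outliers` view (only the columns the grouping needs)
outliers = pd.DataFrame({
    'sepal_length': [5.4, 5.4, 5.8, 5.7, 5.4, 5.7, 5.1, 5.1, 5.2, 5.5, 5.1, 5.1, 5.3, 7.7, 7.9],
    'sepal_width': [3.9, 3.7, 4.0, 4.4, 3.9, 3.8, 3.8, 3.7, 4.1, 4.2, 3.8, 3.8, 3.7, 3.8, 3.8],
    'ytick_loc': [0] * 13 + [2, 2],
})

# Convert to strings so they can be joined into one multi-line annotation
outliers['sepal_length'] = outliers['sepal_length'].astype(str)

# One row per plotted location; points sharing a location get a single stacked label
labels = (outliers.groupby(['sepal_width', 'ytick_loc'])
          .agg({'sepal_length': '\n'.join})
          .reset_index())
print(labels)
```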
    

  • Overlapping Annotations

    import seaborn as sns
    
    # load sample data that has text labels
    df = sns.load_dataset('iris')
    
    # plot the DataFrame
    g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)
    
    # there is only one axes for this plot; provide an alias for ease of use
    ax = g.axes[0, 0]
    
    # get the ytick location for each species label
    ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}
    
    # plot the text annotations
    for x, y, s in zip(df.sepal_width, df.species.map(ytick_loc), df.sepal_length):
        if x > 3.6:
            ax.text(x, y, s, horizontalalignment='left', size='medium', color='k')
    

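If you want one label per point rather than one joined label per location, a different technique (not part of the answer above) is to give repeated points at the same location increasing vertical offsets with `ax.annotate(..., textcoords='offset points')`. This is a minimal sketch using plain matplotlib and a tiny hypothetical DataFrame. The 4-point horizontal offset and the 12-point vertical step are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical points: three share one location, as happens in the overlapping version
pts = pd.DataFrame({
    'x': [3.8, 3.8, 3.8, 3.9],
    'y': [0, 0, 0, 0],
    'label': ['5.7', '5.1', '5.1', '5.4'],
})

fig, ax = plt.subplots()
ax.scatter(pts['x'], pts['y'])

# Number each repeat of the same (x, y) pair: 0 for the first, 1 for the second, ...
pts['nth'] = pts.groupby(['x', 'y']).cumcount()

# Shift each repeat 12 points further down so the labels no longer sit on top of each other
annotations = []
for x, y, label, nth in pts.itertuples(index=False):
    annotations.append(ax.annotate(label, xy=(x, y), xytext=(4, -12 * nth),
                                   textcoords='offset points', va='center'))
```

On a swarm plot the points themselves are spread out, so these offsets only separate the text. The labels still refer to the category's tick position, not to each dot's displaced position.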