python pandas matplotlib seaborn mplcursors

Mouseover annotation/highlight of seaborn `pairplot`


As a minimal, complete, and verifiable example (MCVE), I built the following pairplot:

from sklearn.datasets import make_blobs
import pandas as pd
from sklearn.cluster import HDBSCAN
import seaborn as sns
import numpy as np ; np.random.seed(0)

centers = 4
data, c = make_blobs(n_samples    = 20, 
                     centers      = centers, 
                     n_features   = 3,
                     cluster_std  = np.random.rand(centers) * 2.5,
                     random_state = 0)

df = pd.DataFrame(data)

alg = HDBSCAN()
alg.fit(df)
df['Label'] = alg.labels_.astype(str)

g = sns.pairplot(df, hue = 'Label')

This is a simple pairplot that shows a few outliers and has an underlying DataFrame df.

What I want is to show an annotation with the df.index value of a point when hovering over it, and to highlight that same point in all of the other subplots.

I have found a hover-over annotation approach in this question for the underlying matplotlib.pyplot objects, but the code there does not seem easy to extend to a multi-axes figure like the pairplot above.

I have done this with mplcursors, which gives me the labels (but only by adding another package):

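import mplcursors

def show_hover_panel(get_text_func=None):
    # hover=2 is the transient hover mode: the annotation goes away
    # once the mouse moves off the point
    cursor = mplcursors.cursor(hover=2)
    if get_text_func:
        # replace the default annotation text with the caller's text
        cursor.connect(
            event = "add",
            func  = lambda sel: sel.annotation.set_text(get_text_func(sel.index)),
        )
    return cursor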


def on_add(index):
    print(index)
    ix = df.index[index]
    #size = np.zeros(df.shape[0])
    #size[index] = 1
    #g.map_upper(sns.scatterplot, size = size)
    #g.map_lower(sns.scatterplot, size = size)
    return "{}".format(ix)

show_hover_panel(on_add)

The commented out part of the code is my (very) unsuccessful attempt to make it highlight all the related points. I leave the fairly comical output as an exercise to the reader.

This example shows how to link highlights via mplcursors, but requires every point be its own artist, which is incompatible with seaborn.

Is there any smarter way to do a multi-axis highlight, preferably doing it and the multi-axis annotation natively in matplotlib and seaborn?


Solution

  • Tested with Seaborn 0.13.2 (and 0.12.2) and matplotlib 3.8.3.

    mplcursors is remarkably versatile. For one, the cursor can be connected to elements from different subplots. In the case of the pairplot, we want the scatter dots, which are stored in ax.collections[0]. Provided there are no NaN values, the dots keep the same order as the rows of the dataframe. sel.index is the index into the scatter dot collection, so it can be used to index the dataframe. Extra highlighted elements can also be added to sel.extras; that way, they are removed automatically when a new point is selected.

    Note that mplcursors is a very lightweight library, but reproducing its functionality is a huge amount of work. If you don't want to install it as a dependency, you can also just drop its single Python file into your source directory.

    The code below starts from the mpg dataset, with the rows containing NaN values removed. The colors are chosen so that the highlighted point stands out clearly in every subplot.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import mplcursors
    
    def show_annotation(sel):
        row = mpg.iloc[sel.index]  # selected row from the dataframe
        sel.annotation.set_text(f"{row['name']} ({row.origin} {row.model_year})\nmpg: {row.mpg}  hp:{row.horsepower}")
        sel.annotation.get_bbox_patch().set(fc="lightsalmon", alpha=0.9)
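        # draw a red ring around the same row in every scatter subplot; rings stored
        # in sel.extras are removed by mplcursors when the selection changes
        for ax in g.axes.flat:
            if len(ax.collections) > 0:  # skip axes without scatter dots
                sel.extras.append(
                    ax.scatter(*ax.collections[0].get_offsets()[sel.index], ec='red', fc='none', lw=3, s=50))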
    
    mpg = sns.load_dataset('mpg').dropna()
    
    g = sns.pairplot(mpg, vars=['mpg', 'horsepower', 'weight', 'model_year'], hue='origin', palette='pastel')
    
    cursor = mplcursors.cursor([ax.collections[0] for ax in g.axes.flat if len(ax.collections) > 0], hover=True)
    cursor.connect("add", show_annotation)
    plt.show()
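The question asks for the df.index value rather than a whole row. Because dropna() leaves gaps in the index, the positional index reported by the cursor and the index label can differ. A minimal sketch of the translation follows; label_for_position and the small frame are hypothetical names used only for illustration.

```python
import numpy as np
import pandas as pd

frame = pd.DataFrame({"x": [1.0, np.nan, 3.0, 4.0],
                      "y": [10.0, 20.0, 30.0, 40.0]})
clean = frame.dropna()  # row 1 is dropped, so the index labels are now 0, 2, 3

def label_for_position(data, position):
    """Hypothetical helper: map a positional scatter index to the dataframe's index label."""
    return data.index[position]

print(label_for_position(clean, 1))  # prints 2: the second plotted dot is the row labelled 2
```

Inside a hover callback, something like sel.annotation.set_text(str(label_for_position(df, sel.index))) should then show the label, assuming sel.index is the integer position of the dot.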
    

    [Figure: mplcursors with sns.pairgrid and highlighting in subplots]
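For readers who still want to avoid the extra dependency, the highlighting part alone can be approximated with plain matplotlib event handling. This is not the approach used above, only a rough sketch: the grid of plain scatter plots, the names highlight and on_move, and the Agg backend are all assumptions, and no annotation box is drawn.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; use an interactive one to hover
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(20, 3))

# Hand-made 3x3 "pairplot": one scatter collection per off-diagonal axes.
fig, axes = plt.subplots(3, 3)
scatters = {}
for i in range(3):
    for j in range(3):
        if i != j:
            scatters[axes[i, j]] = axes[i, j].scatter(points[:, j], points[:, i])

highlights = []  # ring artists currently shown

def highlight(index):
    """Draw a red ring around dot `index` in every scatter subplot."""
    for ring in highlights:
        ring.remove()
    highlights.clear()
    for ax, sc in scatters.items():
        x, y = sc.get_offsets()[index]
        highlights.append(ax.scatter(x, y, ec="red", fc="none", lw=2, s=80))
    fig.canvas.draw_idle()

def on_move(event):
    """Mouse-move handler: highlight the first dot under the cursor, if any."""
    sc = scatters.get(event.inaxes)
    if sc is None:
        return
    hit, info = sc.contains(event)
    if hit:
        highlight(int(info["ind"][0]))

fig.canvas.mpl_connect("motion_notify_event", on_move)
highlight(3)  # direct call, standing in for a hover over dot 3
```

With a seaborn pairplot, the scatters dictionary would presumably be built from ax.collections[0] for each axes in g.axes.flat that has a collection. Annotation text, hiding the rings when the mouse leaves a dot, and overlapping dots are all left out here, which is most of what mplcursors handles for you.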