For the sake of an MCVE, I build the following pairplot:
from sklearn.datasets import make_blobs
import pandas as pd
from sklearn.cluster import HDBSCAN
import seaborn as sns
import numpy as np ; np.random.seed(0)
centers = 4
data, c = make_blobs(n_samples = 20,
                     centers = centers,
                     n_features = 3,
                     cluster_std = np.random.rand(centers) * 2.5,
                     random_state = 0)
df = pd.DataFrame(data)
alg = HDBSCAN()
alg.fit(df)
df['Label'] = alg.labels_.astype(str)
g = sns.pairplot(df, hue = 'Label')
The resulting simple pairplot shows a few outliers and has an underlying DataFrame, df.
What I want is to show an annotation of df.index for a point when hovering over it, and to somehow highlight that point in all of the other plots.
I have found the hover-over annotation methodology in this question for the underlying matplotlib.pyplot objects, but the code there doesn't seem very extensible to a multi-axes figure like the pairplot above.
I have done this with mplcursors, which gives me the labels (but only by including an additional package):
import mplcursors

def show_hover_panel(get_text_func=None):
    # hover=2 is the transient hover mode: the annotation disappears when the mouse leaves the point
    cursor = mplcursors.cursor(hover=2)
    if get_text_func:
        cursor.connect(
            event = "add",
            func = lambda sel: sel.annotation.set_text(get_text_func(sel.index)),
        )
    return cursor

def on_add(index):
    print(index)
    ix = df.index[index]
    #size = np.zeros(df.shape[0])
    #size[index] = 1
    #g.map_upper(sns.scatterplot, size = size)
    #g.map_lower(sns.scatterplot, size = size)
    return "{}".format(ix)

show_hover_panel(on_add)
The commented out part of the code is my (very) unsuccessful attempt to make it highlight all the related points. I leave the fairly comical output as an exercise to the reader.
This example shows how to link highlights via mplcursors, but requires every point be its own artist, which is incompatible with seaborn.
Is there any smarter way to do a multi-axis highlight, preferably doing it and the multi-axis annotation natively in matplotlib and seaborn?
Tested with Seaborn 0.13.2 (and 0.12.2) and matplotlib 3.8.3.
mplcursors is remarkably versatile. For one, the cursor can be connected to elements from different subplots. In the case of the pairplot, we want the scatter dots, which are stored in ax.collections[0]. Provided there are no NaN values, the dots remain in the same order as in the dataframe. sel.index is the index into the scatter dot collection, which can be used to index the dataframe. Also, extra highlighted elements can be added to sel.extras. That way, they will be automatically removed when a new point is selected.
Note that mplcursors is a very lightweight library, but reproducing its functionality is a huge amount of work. If you don't want to import it, you can also just drop its only Python file into your source directory.
The code below starts from the mpg dataset, with the NaN values removed. Colors are chosen to clearly see the highlighted point in the diverse subplots.
import matplotlib.pyplot as plt
import seaborn as sns
import mplcursors
def show_annotation(sel):
    row = mpg.iloc[sel.index]  # selected row from the dataframe
    sel.annotation.set_text(f"{row['name']} ({row.origin} {row.model_year})\nmpg: {row.mpg} hp:{row.horsepower}")
    sel.annotation.get_bbox_patch().set(fc="lightsalmon", alpha=0.9)
    for ax in g.axes.flat:
        if len(ax.collections) > 0:
            sel.extras.append(
                ax.scatter(*ax.collections[0].get_offsets()[sel.index], ec='red', fc='none', lw=3, s=50))
mpg = sns.load_dataset('mpg').dropna()
g = sns.pairplot(mpg, vars=['mpg', 'horsepower', 'weight', 'model_year'], hue='origin', palette='pastel')
cursor = mplcursors.cursor([ax.collections[0] for ax in g.axes.flat if len(ax.collections) > 0], hover=True)
cursor.connect("add", show_annotation)
plt.show()
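If you would still rather avoid the extra dependency, a small part of this behaviour can be approximated with matplotlib's own motion_notify_event. The following is only a rough sketch and far less robust than mplcursors (no blitting, no handling of overlapping points beyond taking the first hit). The function name link_hover and its labels parameter are hypothetical; the sketch assumes that every axes keeps its scatter in ax.collections[0] and that all scatters list the points in the same order as labels.

```python
import matplotlib.pyplot as plt


def link_hover(fig, axes, labels):
    """Annotate the hovered scatter point and ring the same point in every axes."""
    axes = [ax for ax in axes if len(ax.collections) > 0]
    scatters = {ax: ax.collections[0] for ax in axes}  # captured before the rings are added
    # One initially empty "ring" scatter and one hidden annotation per axes
    rings = {ax: ax.scatter([], [], s=120, facecolors="none", edgecolors="red", linewidths=2)
             for ax in axes}
    notes = {ax: ax.annotate("", xy=(0, 0), xytext=(10, 10), textcoords="offset points",
                             bbox=dict(fc="lightsalmon", alpha=0.9), visible=False)
             for ax in axes}

    def on_move(event):
        hit_ax, hit_index = None, None
        if event.inaxes in scatters:
            contains, info = scatters[event.inaxes].contains(event)
            if contains:
                hit_ax, hit_index = event.inaxes, int(info["ind"][0])
        for ax in axes:
            if hit_index is None:
                rings[ax].set_visible(False)
            else:
                # Move the ring onto the same point (same position in the collection) in this axes
                rings[ax].set_offsets(scatters[ax].get_offsets()[[hit_index]])
                rings[ax].set_visible(True)
            notes[ax].set_visible(ax is hit_ax)
        if hit_ax is not None:
            notes[hit_ax].xy = tuple(scatters[hit_ax].get_offsets()[hit_index])
            notes[hit_ax].set_text(str(labels[hit_index]))
        fig.canvas.draw_idle()

    # Return the connection id so the callback can be removed with fig.canvas.mpl_disconnect
    return fig.canvas.mpl_connect("motion_notify_event", on_move)
```

For the pairplot above, this would be called as link_hover(g.figure, g.axes.flat, mpg.index) before plt.show(). labels needs to support positional indexing, so a list, an array or a dataframe index works, but a pandas Series with a non-default index does not.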