I am currently trying to program a seaborn stripplot that shows a point's column and index in the dataframe when the mouse hovers over it. This raises a few questions:
What does `stripplot.contains()` return? I get that it returns a boolean saying whether the event lies in the container artist, and a dict giving the labels of the picked data points. But what does this dict actually look like in the case of a 2D DataFrame?
Thank you for your help.
My current program looks as follows, and is largely taken from this issue:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Creating the data frame
A = pd.DataFrame(data=np.random.randint(10, size=(5, 5)))
fig, ax = plt.subplots()
strp = sns.stripplot(A)
# Creating an empty annotation
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"))
annot.set_visible(False)

# Updating the annotation based on what stripplot.contains() returns
def update_annot(ind):
    mX, mY = A[ind["ind"][0]], A[ind["ind"][0]].loc[[ind["ind"][0]]]
    annot.xy = (mX, mY)
    annot.set_text(str(ind["ind"][0]) + ", " + str(ind["ind"][0]))
    annot.get_bbox_patch().set_facecolor("blue")
    annot.get_bbox_patch().set_alpha(0.4)

# Linking the annotation update to the event
def hover(event):
    vis = annot.get_visible()
    # Create the proper annotation if the event occurs within the bounds of the axis
    if event.inaxes == ax:
        cont, ind = strp.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        else:
            if vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()

# Call the hover function when the mouse moves
fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()
```
My guess is that the dict that `stripplot.contains` outputs does not have the shape I am expecting. And since I cannot find a way to print it (once the last line runs, nothing prints anymore), it is hard for me to know...
Thank you!
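Unlike similar matplotlib functions (e.g. `ax.scatter`, which returns the `PathCollection` it creates), `sns.stripplot(...)` returns the `ax` on which the plot has been created. As such, `strp.contains(...)` is the same as `ax.contains(...)`. That call only tests whether the mouse is inside the Axes and returns an empty dict, so `ind["ind"]` raises a `KeyError`.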
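To see what the two kinds of `contains` calls return without a GUI event loop, a synthetic `MouseEvent` can be fed to them directly. The sketch below is only illustrative: it uses the non-interactive Agg backend and plain `ax.scatter` calls as a stand-in for the stripplot (each scatter call creates one `PathCollection`, like each stripplot category), and the coordinates are made up.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: runs without opening a window
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent

fig, ax = plt.subplots()
# Stand-in for a stripplot: each scatter call creates one PathCollection,
# like each category of a stripplot.
ax.scatter([0, 0, 0], [1, 5, 9])
ax.scatter([1, 1], [2, 7])
fig.canvas.draw()  # fix the autoscaled limits before converting coordinates

# Synthetic mouse event placed on the data point (1, 7), in display pixels.
px, py = ax.transData.transform((1, 7))
event = MouseEvent("motion_notify_event", fig.canvas, px, py)

# What strp.contains(event) does when strp is the Axes:
axes_hit, axes_info = ax.contains(event)
print(axes_hit, axes_info)  # the dict is empty, so axes_info["ind"] fails

# What the PathCollection of the second category reports:
coll_hit, coll_info = ax.collections[1].contains(event)
print(coll_hit, coll_info)  # coll_info["ind"] holds the positions of the hovered dots
```

For a stripplot, the positions in `"ind"` index into the dots of that one collection, not into the rows or columns of the DataFrame.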
To add annotations to plots, `mplcursors` is a handy library. The `sel` argument passed to the callback connected to the cursor's `"add"` event has the following fields:

- `sel.target`: the x and y position of the element under the cursor.
- `sel.artist`: the matplotlib element under the cursor. In the case of stripplot, this is a collection of dots grouped in a `PathCollection`. There is one such collection per x-value.
- `sel.index`: the index into the selected artist.

The example code below is tested with seaborn 0.13.2 and matplotlib 3.8.3, starting from Seaborn's tips dataset.
`collection_to_day` is a dictionary that maps each collection to its corresponding x-value. Adapting it to your specific situation might need some tweaks if the x-values aren't of the `pd.Categorical` type.
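When the x column is not categorical, `tips['day'].cat.categories` has no direct equivalent. The sketch below assumes seaborn's default level ordering (categorical dtype: its categories; numeric: sorted unique values; anything else: order of first appearance; NaN dropped). That is an assumption worth checking against your seaborn version, or sidestepping by passing `order=` to `stripplot` explicitly. `level_order` is a hypothetical helper name.

```python
import pandas as pd


def level_order(x):
    """Hypothetical helper: guess the order in which seaborn places the levels of x.

    Assumed rule (verify for your seaborn version): categorical dtype -> its
    categories; numeric -> sorted unique values; otherwise unique values in
    order of first appearance. Missing values are dropped.
    """
    x = pd.Series(x)
    if isinstance(x.dtype, pd.CategoricalDtype):
        return list(x.cat.categories)
    levels = list(x.dropna().unique())
    return sorted(levels) if pd.api.types.is_numeric_dtype(x) else levels
```

It would then replace the `.cat.categories` lookup, e.g. `collection_to_day = dict(zip(ax.collections, level_order(tips['day'])))`.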
`renumbering` is a dictionary that contains an array for each "day". That array maps the index of a dot within its collection to the index of the corresponding row in the original dataframe. The original dataframe should not contain NaN values for the x or y values of the plot, as those will be filtered out.
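Because `renumbering` is passed to `tips.iloc`, which expects positions, the `reset_index()['index']` construction works only because `tips` has a default `RangeIndex` (labels equal positions). A possible variant returns positions directly and also skips rows with NaN in the x or y column. It is a sketch under the same assumption the answer relies on (seaborn keeps the remaining rows of each level in DataFrame order); `build_renumbering` is a hypothetical name and the demo data is made up.

```python
import numpy as np
import pandas as pd


def build_renumbering(df, x, y, levels):
    """Hypothetical helper: map each x level to the positional rows of its dots.

    Assumes seaborn drops rows whose x or y is NaN and keeps the remaining
    rows of each level in DataFrame order.
    """
    plotted = df[x].notna() & df[y].notna()
    # np.flatnonzero gives positions, safe for .iloc with any kind of index.
    return {lvl: np.flatnonzero(plotted & (df[x] == lvl)) for lvl in levels}


# Made-up data with a non-default index and one NaN tip.
demo = pd.DataFrame({"day": ["Sat", "Fri", "Sat", "Sat", "Fri"],
                     "tip": [1.0, 2.0, np.nan, 3.0, 4.0]},
                    index=[10, 11, 12, 13, 14])
renumbering_demo = build_renumbering(demo, "day", "tip", ["Fri", "Sat"])
```

With the tips data this would be called as `build_renumbering(tips, 'day', 'tip', days)`, and `tips.iloc[renumbering[day][sel.index]]` stays correct even if the index were not a `RangeIndex`.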
```python
import matplotlib.pyplot as plt
import seaborn as sns
import mplcursors

def show_annotation(sel):
    # Which x-value (day) does the hovered collection belong to?
    day = collection_to_day[sel.artist]
    # Translate the dot's position in its collection to a row of tips
    index_in_df = renumbering[day][sel.index]
    row = tips.iloc[index_in_df]
    txt = f"Day: {row['day']}\nTime: {row['time']}\nTip: {row['tip']}\nBill total: {row['total_bill']}"
    sel.annotation.set_text(txt)

tips = sns.load_dataset('tips')
fig, ax = plt.subplots()
sns.stripplot(data=tips, x='day', y='tip', hue='time', palette='turbo', ax=ax)

# One PathCollection per day, in the order of the categories
days = tips['day'].cat.categories
collection_to_day = dict(zip(ax.collections, days))
# For each day: position of a dot in its collection -> row index in tips
renumbering = dict()
for day in days:
    renumbering[day] = tips[tips['day'] == day].reset_index()['index']

cursor = mplcursors.cursor(hover=True)
cursor.connect('add', show_annotation)
plt.show()
```
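The hover approach from the question can also work without mplcursors once `contains` is called on the `PathCollection`s in `ax.collections` instead of on the Axes. The sketch below assumes a wide-form plot such as `sns.stripplot(data=A)` without `hue`, `dodge` or NaN values, so that collection `i` holds column `A.columns[i]` with its dots in row order. `attach_hover` is a hypothetical helper name. To stay runnable without seaborn or a display, it draws one `ax.scatter` per column as a stand-in and triggers the handler with a synthetic `MouseEvent`.

```python
import matplotlib
matplotlib.use("Agg")  # headless for this sketch; remove for an interactive window
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent
import numpy as np
import pandas as pd


def attach_hover(fig, ax, df):
    """Hypothetical helper: annotate the hovered dot with "column, index" of df.

    Assumes a wide-form strip plot of df (no hue/dodge, no NaN), so that
    ax.collections[i] holds df.columns[i] with its dots in row order.
    """
    annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                        bbox=dict(boxstyle="round", fc="blue", alpha=0.4))
    annot.set_visible(False)

    def hover(event):
        if event.inaxes != ax:
            return
        # Query each PathCollection: unlike ax.contains, it fills info["ind"].
        for col_pos, coll in enumerate(ax.collections):
            hit, info = coll.contains(event)
            if hit:
                row_pos = info["ind"][0]  # first dot under the cursor
                annot.xy = tuple(coll.get_offsets()[row_pos])
                annot.set_text(f"{df.columns[col_pos]}, {df.index[row_pos]}")
                annot.set_visible(True)
                fig.canvas.draw_idle()
                return
        # Cursor inside the Axes but over no dot: hide a visible annotation.
        if annot.get_visible():
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
    return annot, hover


df = pd.DataFrame(np.arange(6).reshape(3, 2), columns=["a", "b"])
fig, ax = plt.subplots()
# Stand-in for sns.stripplot(data=df, ax=ax): one scatter per column, no jitter.
for col_pos, col in enumerate(df.columns):
    ax.scatter(np.full(len(df), col_pos), df[col])
annot, hover = attach_hover(fig, ax, df)

# Headless check: pretend the mouse sits on column "b", row 2 (value 5).
fig.canvas.draw()
px, py = ax.transData.transform((1, df.loc[2, "b"]))
hover(MouseEvent("motion_notify_event", fig.canvas, px, py))
print(annot.get_text())  # b, 2
```

In an interactive session, drop the `matplotlib.use("Agg")` line and the headless check, draw the plot with `sns.stripplot(data=A, ax=ax)`, call `attach_hover(fig, ax, A)` and finish with `plt.show()`.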