python, matplotlib, mplcursors

Accessing "local data" with mplcursors


I have trouble understanding how mplcursors cursors work. Let me give an example.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mplcursors
%matplotlib qt5

def random_data(rows_count):
    data = []
    for i in range(rows_count):
        row = {}
        row["x"] = np.random.uniform()
        row["y"] = np.random.uniform()
        if (i%2 == 0):
            row["type"] = "sith"
            row["color"] = "red"
        else:
            row["type"] = "jedi"
            row["color"] = "blue"
        data.append(row)
    return pd.DataFrame(data)

data_df = random_data(30)

fig, ax = plt.subplots(figsize=(8,8))
ax = plt.gca()

types = ["jedi","sith"]

for scat_type in types:
    local_data_df = data_df.loc[data_df["type"] == scat_type]
    scat = ax.scatter(local_data_df["x"],
               local_data_df["y"],
               c=local_data_df["color"],
               label=scat_type)
    cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)
    @cursor.connect("add")
    def on_add(sel):
        annotation = (local_data_df.iloc[sel.index]["type"]+
                      "\n"+str(local_data_df.iloc[sel.index]["x"])+
                      "\n"+str(local_data_df.iloc[sel.index]["y"]))
        sel.annotation.set(text=annotation)

ax.legend()
plt.title("a battle of Force users")
plt.xlabel("x")
plt.ylabel("y")
plt.xlim(-1, 2)
plt.ylim(-1, 2)
ax.set_aspect('equal', adjustable='box')
plt.show()

This code is supposed to generate a DataFrame in which each row has random x and y values, a type that is either jedi or sith, and a color that is blue for jedi rows and red for sith rows. It should then scatter-plot the jedi points in their color and attach a cursor to them, scatter-plot the sith points in their color and attach another cursor to them, and display a legend telling the reader that blue points are jedi rows and red points are sith rows. However, when hovering over the points, the annotations say that all of them are sith, and the coordinates don't look right.

I would like to understand why the code does not do what I would like it to do.

Just to clarify: I call .scatter() once per type (jedi or sith) and attach a cursor to each plot because, when I tried calling scatter on the whole data_df, .legend() did not display what I want.

I hope your answer will be enough for me to write code that displays the jedi and sith points with the right annotations and the right legend.


Solution

  • Several things in the code don't behave the way you expect.

    One source of confusion is the idea that assigning local_data_df inside a for loop creates a variable that is local to one iteration of the loop. A Python for loop does not create a new scope: local_data_df is just a global variable that gets overwritten on each iteration. Similarly, defining on_add inside the for loop doesn't make it local; on_add is also global and is redefined on each iteration.
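A quick way to see this, independent of matplotlib, is a loop that assigns a name and defines a function: both names are still there at module level once the loop has finished. The names below are illustrative only.

```python
# A for loop does not open a new scope: names assigned or defined in its body
# live in the enclosing (here: module) scope and are overwritten on each iteration.
for scat_type in ["jedi", "sith"]:
    local_value = scat_type.upper()

    def on_add():
        # Reads the module-level name scat_type at call time.
        return scat_type

# Both names survive the loop and hold whatever the last iteration assigned.
print(local_value)  # SITH
print(on_add())     # sith
```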

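    Another confusion is the assumption that the connected function remembers the value local_data_df had when the function was defined. Instead, Python looks the name up when the function is called, i.e. when you hover over a point. By then the loop has finished and local_data_df refers to the last DataFrame, the sith one, which is why every annotation says sith and the coordinates come from the wrong rows.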
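The following plain-Python sketch (hypothetical names) shows the same late lookup inside a real enclosing function: the name stays reachable after the function returns, but it holds its last value. It also shows one common workaround, binding the current value as a default argument. Since mplcursors calls an "add" callback with the selection as its only argument, a signature such as on_add(sel, local_data_df=local_data_df) should behave the same way; that is an alternative to storing the data on the artist, not what this answer's rewrite does.

```python
def make_callbacks(bind_now):
    """Build one callback per group; bind_now selects the binding strategy."""
    callbacks = []
    for group in ["jedi", "sith"]:
        if bind_now:
            # Default values are evaluated when `def` runs, i.e. once per iteration.
            def on_add(group=group):
                return group
        else:
            # `group` is looked up when on_add() is called, after the loop is done.
            def on_add():
                return group
        callbacks.append(on_add)
    return callbacks


late = [cb() for cb in make_callbacks(bind_now=False)]
bound = [cb() for cb in make_callbacks(bind_now=True)]
print(late)   # ['sith', 'sith']
print(bound)  # ['jedi', 'sith']
```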

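    Further, note that sel.index is not a label from the DataFrame's index but the position of the point within the scatter plot it belongs to (0 for the first point of that scatter call, 1 for the second, and so on). Because the code uses .iloc, which is positional, the lookup itself is fine once it hits the right DataFrame. If you prefer .loc, reset the index of the "local df" first so its labels run 0, 1, 2 and so on, in the same order as sel.index.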
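A small pandas sketch with made-up data illustrates the difference. Here point_index is a hypothetical stand-in for sel.index, i.e. the position of a point inside one scatter call.

```python
import pandas as pd

# Made-up data: alternating sith/jedi rows with default labels 0..3.
data_df = pd.DataFrame({
    "x": [0.1, 0.2, 0.3, 0.4],
    "type": ["sith", "jedi", "sith", "jedi"],
})
jedi_df = data_df.loc[data_df["type"] == "jedi"]
print(jedi_df.index.tolist())  # [1, 3]: filtering keeps the original labels

point_index = 0  # stand-in for sel.index: first point of the jedi scatter
print(jedi_df.iloc[point_index]["x"])  # 0.2: .iloc is positional, so this is right
print(point_index in jedi_df.index)    # False: label 0 is a sith row

jedi_reset = jedi_df.reset_index(drop=True)
print(jedi_reset.index.tolist())         # [0, 1]
print(jedi_reset.loc[point_index]["x"])  # 0.2: labels now match positions
```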

    To mimic your per-plot local variable, you can attach extra data to the scat object. For example, scat.my_data = local_data_df adds that attribute to the object that holds the scatter element (the PathCollection containing all the information matplotlib needs to draw the scatter points). Although the variable scat gets overwritten, each call to ax.scatter creates its own PathCollection, and inside the callback sel.artist is the PathCollection of the hovered point, so sel.artist.my_data returns the matching DataFrame. (You can also access these collections via ax.collections.)
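The attribute trick can be checked without hovering at all. This sketch uses the non-interactive Agg backend and a hypothetical stand-in for the mplcursors selection (a SimpleNamespace with only artist and index), so it assumes nothing about mplcursors beyond those two attributes, which the answer's on_add also relies on.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed for this check
import matplotlib.pyplot as plt
import pandas as pd
from types import SimpleNamespace

fig, ax = plt.subplots()
for scat_type, color in [("jedi", "blue"), ("sith", "red")]:
    local_data_df = pd.DataFrame(
        {"x": [0.1, 0.2], "y": [0.3, 0.4], "type": [scat_type, scat_type]}
    )
    scat = ax.scatter(local_data_df["x"], local_data_df["y"], c=color, label=scat_type)
    scat.my_data = local_data_df  # plain Python attribute on this PathCollection

# `scat` now points at the sith collection, but both collections still exist.
print(len(ax.collections))                                  # 2
print([c.my_data["type"].iloc[0] for c in ax.collections])  # ['jedi', 'sith']

# Hypothetical stand-in for an mplcursors Selection (only .artist and .index).
fake_sel = SimpleNamespace(artist=ax.collections[0], index=1)
print(fake_sel.artist.my_data.iloc[fake_sel.index]["type"])  # jedi
```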

    Here is a rewrite of your code, trying to stay as close as possible to the original:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import mplcursors
    
    def random_data(rows_count):
        df = pd.DataFrame({'x': np.random.uniform(0, 1, rows_count),
                           'y': np.random.uniform(0, 1, rows_count),
                           'type': np.random.choice(['sith', 'jedi'], rows_count)})
        df['color'] = df['type'].replace({'sith': 'red', 'jedi': 'blue'})
        return df
    
    def on_add(sel):
        local_data_df = sel.artist.my_data
        annotation = (local_data_df.iloc[sel.index]["type"] +
                      "\n" + str(local_data_df.iloc[sel.index]["x"]) +
                      "\n" + str(local_data_df.iloc[sel.index]["y"]))
        sel.annotation.set(text=annotation)
    
    data_df = random_data(30)
    
    fig, ax = plt.subplots(figsize=(8, 8))
    
    types = ["jedi", "sith"]
    
    for scat_type in types:
        local_data_df = data_df.loc[data_df["type"] == scat_type].reset_index(drop=True)  # labels 0..n-1 match sel.index (needed for .loc; .iloc is positional anyway)
        scat = ax.scatter(local_data_df["x"],
                          local_data_df["y"],
                          c=local_data_df["color"],
                          label=scat_type)
        scat.my_data = local_data_df # store the data into the scat object
        cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)
        cursor.connect("add", on_add)
    
    ax.legend()
    ax.set_title("a battle of Force users")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim(-1, 2)
    ax.set_ylim(-1, 2)
    ax.set_aspect('equal', adjustable='box')
    plt.show()
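If you'd rather not attach attributes to matplotlib objects, functools.partial can bind each group's DataFrame to the callback instead. This is a swapped-in technique, not part of the answer above; it assumes, as the answer does, that mplcursors passes the selection as the single argument to an "add" callback. The sketch also looks the row up once and formats the coordinates to three decimals. FakeAnnotation and the SimpleNamespace selection are hypothetical stand-ins so the callback can be exercised without a GUI.

```python
from functools import partial
from types import SimpleNamespace

import pandas as pd


def on_add(sel, local_data_df):
    """Annotate the hovered point using the DataFrame bound to this callback."""
    row = local_data_df.iloc[sel.index]  # positional: same order as the points
    sel.annotation.set(text=f"{row['type']}\n{row['x']:.3f}\n{row['y']:.3f}")


class FakeAnnotation:
    """Hypothetical stand-in for the annotation mplcursors creates."""

    def __init__(self):
        self.text = None

    def set(self, text):
        self.text = text


jedi_df = pd.DataFrame({"x": [0.1, 0.2], "y": [0.3, 0.4], "type": ["jedi", "jedi"]})
callback = partial(on_add, local_data_df=jedi_df)  # one bound callback per group

sel = SimpleNamespace(index=1, annotation=FakeAnnotation())
callback(sel)  # called the way mplcursors would call it: with just the selection
print(sel.annotation.text)
```

In the loop, the connection would then be cursor.connect("add", partial(on_add, local_data_df=local_data_df)), and the scat.my_data line is no longer needed.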
    

    (Screenshot: mplcursors with "local data")