
Create a legend taking into account both the size and color of a scatter plot


I am plotting a dataset as a scatter plot in Python, encoding the data in both color and size. I'd like the legend to represent both.

I am aware of .legend_elements(prop='sizes'), but it gives me either colors or sizes, not both at the same time. I found a way to change the marker color when using prop='sizes' via the color argument, but that's not really what I intend (all the markers end up the same color).
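For illustration, here is a minimal sketch of the limitation described above. It uses synthetic NumPy arrays instead of the DataFrames from the MWE, and the random seed and the "tab:blue" colour are arbitrary choices. Each legend_elements call varies only one property: the size-based handles share a single colour (the one passed via color=), and the colour-based handles share a single marker size.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, synthetic data
time = rng.random(10)
intensity = rng.integers(1, 5, 10)  # integers 1..4, as with np.random.randint(1, 5, 10)

fig, ax = plt.subplots()
scat = ax.scatter(time, intensity, c=intensity, s=10 * intensity**2)

# Size-based handles: marker sizes vary, but every handle gets the same colour.
size_handles, size_labels = scat.legend_elements(
    prop="sizes", num=3, func=lambda s: np.sqrt(s / 10), color="tab:blue"
)

# Colour-based handles: colours follow the colormap, but every handle gets the same size.
color_handles, color_labels = scat.legend_elements(prop="colors", num=3)

print("size handle colours:", {h.get_color() for h in size_handles})
print("colour handle sizes:", {h.get_markersize() for h in color_handles})
```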

Here is a MWE:

import pandas as pd
import numpy as np
import matplotlib.pyplot as pl

time = pd.DataFrame(np.random.rand(10))
intensity = pd.DataFrame(np.random.randint(1,5,10))
df = pd.concat([time, intensity], axis=1)

size = intensity.apply(lambda x: 10*x**2)

fig, ax = pl.subplots()
scat = ax.scatter(time, intensity, c=intensity, s=size)

lgd = ax.legend(*scat.legend_elements(prop="sizes", num=3,
                                      fmt="{x:.1f}",
                                      func=lambda s: np.sqrt(s/10)),
                title="intensity")

and I'd like to have the markers color-coded too.

Any help or hint would be appreciated!


Solution

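  • Using legend_elements, you can get the size-based and colour-based legend elements separately, then set the colours of the former from the latter. Use the same num in both calls so that the two lists of handles refer to the same intensity values. E.g.,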

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as pl
    
    time = pd.DataFrame(np.random.rand(10))
    intensity = pd.DataFrame(np.random.randint(1,5,10))
    df = pd.concat([time, intensity], axis=1)
    
    size = intensity.apply(lambda x: 10*x**2)
    
    fig, ax = pl.subplots()
    scat = ax.scatter(time, intensity, c=intensity, s=size)
    
    # get size-based legend handles
    size_handles, text = scat.legend_elements(
        prop="sizes",
        num=3,
        fmt="{x:.1f}",
        func=lambda s: np.sqrt(s/10)
    )
    
    # get colour-based legend handles
    colors = [c.get_color() for c in scat.legend_elements(prop="colors", num=3)[0]]
    
    # set colours of the size-based legend handles
    for handle, c in zip(size_handles, colors):
        handle.set_color(c)
    
    # add the legend
    lgd = ax.legend(size_handles, text, title="intensity")
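A possible variation, sketched here under the assumption that the intensity values are few and discrete (as with randint(1, 5) in the question): instead of matching two handle lists by index, build one legend entry per unique intensity with matplotlib.lines.Line2D, taking its colour from the scatter's own colormap and norm (scat.cmap(scat.norm(value))) and its marker size from the same size formula. The synthetic data and seed below are illustrative, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed, synthetic data
time = rng.random(10)
intensity = rng.integers(1, 5, 10)
sizes = 10 * intensity**2  # marker area in points**2, as in the question

fig, ax = plt.subplots()
scat = ax.scatter(time, intensity, c=intensity, s=sizes)

# One legend entry per distinct intensity value.
handles = []
labels = []
for value in np.unique(intensity):
    handles.append(
        Line2D(
            [], [], linestyle="", marker="o",
            # same colour the scatter assigned to this value
            color=scat.cmap(scat.norm(value)),
            # scatter's s is an area (points**2); Line2D markersize is a width (points)
            markersize=np.sqrt(10 * value**2),
        )
    )
    labels.append(f"{value}")

lgd = ax.legend(handles, labels, title="intensity")
```

This avoids relying on the size and colour handle lists lining up by position. The trade-off is one entry per unique value rather than a locator-chosen subset, so for continuous data the legend_elements approach above is likely the better fit.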