Tags: python, matplotlib, plot, legend

Combine multiple line labels in legend


I have data that results in multiple lines being plotted, and I want to give all of these lines a single label in my legend. This is easiest to demonstrate with the example below:

import numpy as np
import matplotlib.pyplot as plt

a = np.array([[ 3.57,  1.76,  7.42,  6.52],
              [ 1.57,  1.2 ,  3.02,  6.88],
              [ 2.23,  4.86,  5.12,  2.81],
              [ 4.48,  1.38,  2.14,  0.86],
              [ 6.68,  1.72,  8.56,  3.23]])

# a[:, ::2].T and a[:, 1::2].T are 2 x 5 arrays; plot() draws one line per
# column (5 lines) and gives every one of them the same label, so the legend
# ends up listing 'data_a' once per line.
plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')

plt.legend(loc='best')

As you can see from the output (Out[23] in my session), the call produced 5 distinct lines, so the legend shows the label 'data_a' once for every line. [Image: legend of multiple line plot]

Is there any way to tell the plot method not to create multiple legend entries? I would prefer to avoid a custom legend (where you specify the label and the line style together) if at all possible.


Solution

  • Personally, I'd write a small helper function if I planned on doing this often:

    from matplotlib import pyplot
    import numpy
    
    
    # Sample data: five rows of four floats each.
    a = numpy.array((
        (3.57, 1.76, 7.42, 6.52),
        (1.57, 1.20, 3.02, 6.88),
        (2.23, 4.86, 5.12, 2.81),
        (4.48, 1.38, 2.14, 0.86),
        (6.68, 1.72, 8.56, 3.23),
    ))
    
    
    def plotCollection(ax, xs, ys, *args, **kwargs):
        """Plot lines on *ax*, keeping one legend entry per distinct label.

        ``xs``, ``ys`` and any extra positional/keyword arguments are passed
        unchanged to ``ax.plot``. When a ``label`` keyword is supplied, the
        legend of *ax* is rebuilt so each label appears only once, using the
        handle of the first line that carried it.

        Returns whatever ``ax.plot`` returned (the list of created lines).
        """
        lines = ax.plot(xs, ys, *args, **kwargs)

        if "label" in kwargs:
            handles, labels = ax.get_legend_handles_labels()
            # dict keeps insertion order, so the legend follows plotting order;
            # setdefault keeps the first handle seen for a repeated label.
            first_handle = {}
            for handle, label in zip(handles, labels):
                first_handle.setdefault(label, handle)
            # Use *ax* itself rather than pyplot's "current" axes, so the legend
            # is built from (and drawn on) the axes that was actually plotted on.
            ax.legend(list(first_handle.values()), list(first_handle.keys()))

        return lines
    
    ax = pyplot.subplot(1, 1, 1)
    # One plotCollection call per data set: (x values, y values, colour, label).
    for xs, ys, colour, name in (
        (a[:, ::2].T, a[:, 1::2].T, 'r', 'data_a'),
        (a[:, 1::2].T, a[:, ::2].T, 'b', 'data_b'),
    ):
        plotCollection(ax, xs, ys, colour, label=name)
    pyplot.show()
    

    An easier (and, IMO, clearer) way than the loop above to remove duplicate handles and labels from the legend is this:

    handles, labels = pyplot.gca().get_legend_handles_labels()
    # Keying a dict on the labels drops duplicates in one step while keeping
    # each label at its first position (dicts preserve insertion order). A
    # repeated label keeps the handle of its last line, which typically looks
    # the same when the duplicates came from one plot() call with one style.
    by_label = dict(zip(labels, handles))
    pyplot.legend(list(by_label.values()), list(by_label.keys()))