pyplot combine multiple line labels in legend

I have data that results in multiple lines being plotted, and I want to give these lines a single label in my legend. The example below demonstrates the problem:

    import numpy as np
    import matplotlib.pyplot as plt

    a = np.array([[ 3.57,  1.76,  7.42,  6.52],
                  [ 1.57,  1.2 ,  3.02,  6.88],
                  [ 2.23,  4.86,  5.12,  2.81],
                  [ 4.48,  1.38,  2.14,  0.86],
                  [ 6.68,  1.72,  8.56,  3.23]])

    plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')

    plt.legend(loc='best')

As you can see at Out[23], the plot resulted in 5 distinct lines, and the legend repeats the label once for each of them. The resulting plot looks like this:

(image: legend of multiple line plot)

Is there any way to tell the plot method to avoid the repeated labels? I would like to avoid a custom legend (where you specify the label and the line shape all at once) if at all possible.

hashmuke asked Oct 13 '14 10:10



2 Answers

Personally, I'd make a small helper function if I planned on doing this often:

    from matplotlib import pyplot
    import numpy

    a = numpy.array([[ 3.57,  1.76,  7.42,  6.52],
                     [ 1.57,  1.2 ,  3.02,  6.88],
                     [ 2.23,  4.86,  5.12,  2.81],
                     [ 4.48,  1.38,  2.14,  0.86],
                     [ 6.68,  1.72,  8.56,  3.23]])

    def plotCollection(ax, xs, ys, *args, **kwargs):
        ax.plot(xs, ys, *args, **kwargs)

        if "label" in kwargs:
            # remove duplicates: keep the first handle seen for each label
            handles, labels = ax.get_legend_handles_labels()
            newLabels, newHandles = [], []
            for handle, label in zip(handles, labels):
                if label not in newLabels:
                    newLabels.append(label)
                    newHandles.append(handle)

            # redraw the legend on the same axes the lines were plotted on
            ax.legend(newHandles, newLabels)

    ax = pyplot.subplot(1, 1, 1)
    plotCollection(ax, a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
    plotCollection(ax, a[:,1::2].T, a[:, ::2].T, 'b', label='data_b')
    pyplot.show()
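If a helper is more machinery than needed, a lighter option is to leave `label` out of the plot call and attach it to just one of the returned lines. This is a minimal sketch, not part of the original answer; it relies on matplotlib skipping artists whose label starts with an underscore (the default for unlabeled lines) when it builds the legend. The `Agg` backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()

# Plot without a label: ax.plot returns one Line2D per column
lines = ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r')

# Label only the first line; the others keep their default
# underscore-prefixed labels and are left out of the legend
lines[0].set_label('data_a')

handles, labels = ax.get_legend_handles_labels()
ax.legend(loc='best')
```

The trade-off is that the legend entry is tied to the first line only, so this fits best when all lines in the group share the same style.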

An easier (and, IMO, clearer) way than what you have to remove duplicates from the legend's handles and labels is this:

    handles, labels = pyplot.gca().get_legend_handles_labels()
    newLabels, newHandles = [], []
    for handle, label in zip(handles, labels):
        if label not in newLabels:
            newLabels.append(label)
            newHandles.append(handle)
    pyplot.legend(newHandles, newLabels)

will answered Sep 19 '22 11:09
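The same de-duplication can be sketched more compactly with a dict, since dict keys are unique and keep their first-insertion order. This is an alternative to the loop above rather than the answer's own code; note that for a repeated label the dict ends up holding the last handle seen, not the first, which makes no visible difference as long as lines sharing a label share a style. The name `by_label` is purely illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r', label='data_a')
ax.plot(a[:, 1::2].T, a[:, ::2].T, 'b', label='data_b')

# Ten handles and ten labels come back: five per plot call
handles, labels = ax.get_legend_handles_labels()

# Map each label to one handle; duplicate labels collapse into a single key
by_label = dict(zip(labels, handles))
legend = ax.legend(by_label.values(), by_label.keys())

# Collect the text of the legend entries for inspection
legend_texts = [text.get_text() for text in legend.get_texts()]
```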


A NumPy solution based on will's answer above.

    import numpy as np
    import matplotlib.pyplot as plt

    a = np.array([[3.57, 1.76, 7.42, 6.52],
                  [1.57, 1.20, 3.02, 6.88],
                  [2.23, 4.86, 5.12, 2.81],
                  [4.48, 1.38, 2.14, 0.86],
                  [6.68, 1.72, 8.56, 3.23]])

    plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
    handles, labels = plt.gca().get_legend_handles_labels()

Assuming that equal labels have equivalent handles, get the unique labels and the index of each one's first occurrence; those indices also select the matching handles.

    labels, ids = np.unique(labels, return_index=True)
    handles = [handles[i] for i in ids]
    plt.legend(handles, labels, loc='best')
    plt.show()

rafaelvalle answered Sep 20 '22 11:09
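One thing to be aware of with this approach: `np.unique` returns the labels sorted, so with more than one label the legend order may differ from the plotting order. A minimal sketch of restoring the original order by sorting the first-occurrence indices follows. The label and handle lists here are hypothetical stand-ins (plain strings instead of real `Line2D` objects) so the index logic can be seen in isolation.

```python
import numpy as np

# Hypothetical stand-ins for what get_legend_handles_labels() returns
labels = ['data_b', 'data_b', 'data_a', 'data_a', 'data_b']
handles = ['h0', 'h1', 'h2', 'h3', 'h4']

# np.unique sorts the labels and reports where each one first appears
unique_labels, ids = np.unique(labels, return_index=True)

# Sorting the first-occurrence indices restores the plotting order
order = np.sort(ids)
ordered_labels = [labels[i] for i in order]
ordered_handles = [handles[i] for i in order]
```

Passing `ordered_handles` and `ordered_labels` to `plt.legend` would then list the entries in the order they were first plotted.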