 

How to add a second legend for subgroups in matplotlib scatterplot

I am making a plot with matplotlib that uses a colormap to give each subgroup within the plot a different color. However, for plotting purposes the subgroups are all one set of x/y pairs (a single scatter call).

plt.scatter(rs1.x,rs1.y, marker = 'D', color=cmap ,label='data')
plt.plot(rs1.x,rs1.hub_results.predict(), marker = 'x', color = 'g',label = 'Huber Fit')
plt.plot(rs1.ol_x,rs1.ol_y, marker = 'x', color='r', ms=10, mew=2, linestyle = ' ', label='Outliers')

It produces the image shown below. The colors come out as I mapped them, so that part works fine, but I have not been able to figure out how to add a second legend to the plot that shows what each color means. I'd appreciate any guidance on this.

Thanks, Charlie

[Image: plot produced by the code above]

asked Feb 12 '23 15:02 by Charlie_M



1 Answer

Below is an example of how to do this. Basically, you make two calls to legend. On the first call, save the Legend object it returns to a variable. An Axes shows only one legend at a time, so the second call replaces the first legend; afterwards you add the first one back manually with the Axes.add_artist method.

import matplotlib.pyplot as plt
import numpy as np

# Four random points for the two scatter groups
x = np.random.uniform(-1, 1, 4)
y = np.random.uniform(-1, 1, 4)

# Two lines; keep the Legend object returned by the first legend call
p1, = plt.plot([1, 2, 3])
p2, = plt.plot([3, 2, 1])
l1 = plt.legend([p2, p1], ["line 2", "line 1"], loc='upper left')

# Two scatter groups drawn in different colors
p3 = plt.scatter(x[0:2], y[0:2], marker='D', color='r')
p4 = plt.scatter(x[2:], y[2:], marker='D', color='g')

# This replaces l1 as the axes' legend, so l1 would no longer be drawn.
plt.legend([p3, p4], ['label', 'label1'], loc='lower right', scatterpoints=1)
# Add l1 back to the axes as a separate artist
plt.gca().add_artist(l1)

plt.show()

[Image: Two labels]
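Applied to the situation in the question, where every subgroup is drawn by a single scatter call and only the colors differ, the same pattern should work if the second legend is built from proxy artists: one marker-only Line2D per subgroup color. The sketch below is an illustration under assumptions, not the asker's code: the data, the subgroup names ("group A" to "group C"), and the use of the viridis colormap are hypothetical, and a least-squares line from np.polyfit stands in for the question's Huber fit. Swap in rs1 and your own colors and labels.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# Hypothetical data standing in for the question's rs1.x / rs1.y:
# three subgroups, all plotted with a single scatter call.
rng = np.random.default_rng(0)
group_names = ["group A", "group B", "group C"]   # hypothetical labels
group = np.repeat([0, 1, 2], 10)                    # subgroup index of each point
x = rng.uniform(0, 10, group.size)
y = 2.0 * x + rng.normal(0, 2, group.size)

# One color per subgroup, sampled evenly from a colormap (assumed: viridis)
cmap = plt.cm.viridis
group_colors = [cmap(i / (len(group_names) - 1)) for i in range(len(group_names))]
point_colors = [group_colors[g] for g in group]

fig, ax = plt.subplots()
ax.scatter(x, y, marker="D", color=point_colors, label="data")

# Least-squares line as a stand-in for the question's Huber fit
slope, intercept = np.polyfit(x, y, 1)
xs = np.sort(x)
ax.plot(xs, slope * xs + intercept, color="g", label="Linear fit")

# First legend: every labeled artist currently on the axes
legend1 = ax.legend(loc="upper left")

# Proxy artists: empty Line2D objects that only carry a marker style and color
handles = [Line2D([], [], marker="D", linestyle="", color=c, label=name)
           for c, name in zip(group_colors, group_names)]

# This call replaces legend1 as the axes' legend...
legend2 = ax.legend(handles=handles, loc="lower right", title="Subgroups")
# ...so add legend1 back as a separate artist.
ax.add_artist(legend1)

plt.show()
```

Because the proxy handles are never added to the axes, they only affect the legend, and the order of the handles list sets the order of the subgroup entries.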

answered Feb 15 '23 04:02 by tbekolay
