I have made a simple scatterplot using matplotlib showing data from 2 numerical variables (varA and varB) with colors that I defined with a 3rd categorical string variable (col) containing 10 unique colors (corresponding to another string variable with 10 unique names), all in the same Pandas DataFrame with 100+ rows. Is there an easy way to create a legend for this scatterplot that shows the unique colored dots and their corresponding category names? Or should I somehow group the data and plot each category in a subplot to do this? This is what I have so far:
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
varA = df['A']
varB = df['B']
col = df['Color']
plt.scatter(varA,varB, c=col, alpha=0.8)
plt.legend()
plt.show()
I had to chime in, because I could not accept that I needed a for-loop to accomplish this. It just seems really annoying and unpythonic, especially when I'm not using Pandas. However, after some searching, I found the answer. The object returned by scatter() is a PathCollection (defined in matplotlib.collections), and its legend_elements() method builds the legend handles for you. See implementation below:
# imports
import matplotlib.collections  # home of PathCollection; importing it explicitly is optional
import matplotlib.pyplot as plt
import numpy as np
# create random data and numerical labels
x = np.random.rand(10,2)
# cycle through 0..3 so every category appears and the handles line up with the labels
y = np.arange(10) % 4
# create list of categories
labels = ['type1', 'type2', 'type3', 'type4']
# plot
fig, ax = plt.subplots()
scatter = ax.scatter(x[:,0], x[:,1], c=y)
handles, _ = scatter.legend_elements(prop="colors", alpha=0.6) # use my own labels
legend1 = ax.legend(handles, labels, loc="upper right")
ax.add_artist(legend1)
plt.show()
[Image: scatterplot legend with custom labels]
Source:
https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_with_legend.html
https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements
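For the DataFrame case in the question, the same legend_elements() approach could be adapted roughly as follows. This is only a sketch under assumptions: the DataFrame and its 'Name' column are hypothetical stand-ins for the asker's data, and the colors come from the default colormap rather than a user-defined color column. Passing num=None asks for one handle per unique value, which matters once there are around ten categories, because the default num="auto" may switch to a locator-based subset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's DataFrame; 'Name' is an assumed column
rng = np.random.default_rng(0)
names = ["type1", "type2", "type3", "type4"]
df = pd.DataFrame({
    "A": rng.random(40),
    "B": rng.random(40),
    "Name": np.tile(names, 10),  # every category appears 10 times
})

# factorize maps each category to an integer code; uniques[i] is the name for code i
codes, uniques = pd.factorize(df["Name"], sort=True)

fig, ax = plt.subplots()
scatter = ax.scatter(df["A"], df["B"], c=codes, alpha=0.8)

# num=None requests one handle per unique code, in ascending code order,
# so the handles line up with `uniques`
handles, _ = scatter.legend_elements(prop="colors", num=None)
legend = ax.legend(handles, list(uniques), loc="upper right", title="Name")
```

Call plt.show() afterwards to display the figure.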
Assuming Color is the column that holds both the colors and the labels, you can simply do the following:
colors = df['Color'].unique()
for color in colors:
    # plot each color group separately so it gets its own legend entry
    data = df.loc[df['Color'] == color]
    plt.scatter(data['A'], data['B'], color=color, label=color)
plt.legend()
plt.show()
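If the legend should show the category names rather than the color strings, a variation on the loop above might group on both columns at once. This is a sketch, assuming a hypothetical 'Name' column that maps one-to-one to 'Color'; the sample data is made up for illustration.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical data: 'Name' and its one-to-one 'Color' are assumed column names
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "B": [2.0, 1.5, 3.5, 3.0, 5.5, 4.0],
    "Name": ["cat", "dog", "cat", "bird", "dog", "bird"],
    "Color": ["red", "blue", "red", "green", "blue", "green"],
})

fig, ax = plt.subplots()
# One scatter call per (name, color) pair gives each category its own legend entry
for (name, color), group in df.groupby(["Name", "Color"]):
    ax.scatter(group["A"], group["B"], color=color, alpha=0.8, label=name)
legend = ax.legend(title="Name")
```

Everything stays on a single Axes, so no subplots are needed; groupby sorts the keys, so the legend entries appear in alphabetical order of the names.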