Matplotlib scatter legend with colors using categorical variable

I have made a simple scatterplot using matplotlib showing data from 2 numerical variables (varA and varB) with colors that I defined with a 3rd categorical string variable (col) containing 10 unique colors (corresponding to another string variable with 10 unique names), all in the same Pandas DataFrame with 100+ rows. Is there an easy way to create a legend for this scatterplot that shows the unique colored dots and their corresponding category names? Or should I somehow group the data and plot each category in a subplot to do this? This is what I have so far:

import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

varA = df['A']
varB = df['B'] 
col = df['Color']

plt.scatter(varA,varB, c=col, alpha=0.8)
plt.legend()

plt.show()
Mary asked Apr 19 '26 21:04


2 Answers

I had to chime in, because I could not accept that I needed a for-loop to accomplish this. It seems annoying and unpythonic, especially when I'm not using Pandas. After some searching, I found the answer: ax.scatter() returns a PathCollection (defined in the matplotlib.collections module), and that class has a legend_elements() method that builds the legend handles from the plotted values. See the implementation below:

# imports
import matplotlib.collections  # optional: ax.scatter already returns a PathCollection
import matplotlib.pyplot as plt
import numpy as np

# create random data and numerical labels
x = np.random.rand(10,2)
y = np.random.randint(4, size=10)

# create list of categories
labels = ['type1', 'type2', 'type3', 'type4']

# plot
fig, ax = plt.subplots()
scatter = ax.scatter(x[:,0], x[:,1], c=y)
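handles, _ = scatter.legend_elements(prop="colors", alpha=0.6) # use my own labels
# legend_elements() returns one handle per value that actually occurs in y,
# so pick only the matching labels in case the random draw skipped a class
legend1 = ax.legend(handles, [labels[i] for i in np.unique(y)], loc="upper right")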
ax.add_artist(legend1)
plt.show()

scatterplot legend with custom labels

Source:

https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_with_legend.html

https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements
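
To connect this back to the question, where the categories are strings in a DataFrame column, here is a minimal sketch of the same legend_elements() approach. The DataFrame contents, the 'Name' column and the "tab10" colormap are assumptions made for illustration; they are not taken from the asker's data. pd.factorize turns the names into integer codes, and num=None requests one legend entry per distinct code, which matters once there are many categories (the question mentions 10).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the asker's DataFrame
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "B": [2.0, 1.5, 3.5, 3.0, 5.0, 4.5],
    "Name": ["cat", "dog", "cat", "bird", "dog", "bird"],
})

# Map each distinct name to an integer code, in order of first appearance
codes, names = pd.factorize(df["Name"])

fig, ax = plt.subplots()
scatter = ax.scatter(df["A"], df["B"], c=codes, cmap="tab10", alpha=0.8)

# One handle per distinct code, in ascending code order, which matches `names`
handles, _ = scatter.legend_elements(prop="colors", num=None)
ax.legend(handles, list(names), title="Name", loc="upper left")
# call plt.show() or fig.savefig(...) to view the result
```

This version lets the colormap choose the colors. If the colors must come from an existing column, plotting one group per category is the more direct route.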

James answered Apr 22 '26 18:04


Assuming Color is the column that holds both the colors and the labels, you can simply do the following.

import matplotlib.pyplot as plt

# plot each color group separately so that each one gets its own legend entry
for color in df['Color'].unique():
    data = df.loc[df['Color'] == color]
    plt.scatter(data['A'], data['B'], color=color, label=color)
plt.legend()
plt.show()
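
The question also mentions a second string column holding the category names. A hedged variation of the loop above, assuming that column is called 'Name' (a hypothetical name) and that each name maps to exactly one color, groups on both columns so the legend shows the names while the dots keep their colors. The sample data is invented for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the asker's DataFrame; 'Name' is an assumed column
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "B": [2.0, 1.5, 3.5, 3.0, 5.0, 4.5],
    "Color": ["red", "blue", "red", "green", "blue", "green"],
    "Name": ["cat", "dog", "cat", "bird", "dog", "bird"],
})

fig, ax = plt.subplots()
# sort=False keeps the groups in order of first appearance
for (name, color), group in df.groupby(["Name", "Color"], sort=False):
    ax.scatter(group["A"], group["B"], color=color, alpha=0.8, label=name)
ax.legend(title="Name")
# call plt.show() or fig.savefig(...) to view the result
```

Each scatter call draws one category, so the legend gets exactly one entry per name without any manual handle construction.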
harvpan answered Apr 22 '26 16:04