 

scatter plot with legend colored by group without multiple calls to plt.scatter

pyplot.scatter accepts an array of group labels for c=, and it colors the points by those groups. However, this does not seem to support generating a legend unless each group is plotted separately.

So, for example, a scatter plot with groups colored can be generated by iterating over the groups and plotting each separately:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
feats = load_iris()['data']
target = load_iris()['target']

f, ax = plt.subplots(1)
# One scatter call per group, each with its own label for the legend
for i in np.unique(target):
    mask = target == i
    ax.scatter(feats[mask, 0], feats[mask, 1], label=i)
ax.legend()

Which generates a scatter plot with each group in its own color and a matching legend.

I can achieve a similar-looking plot without iterating over the groups, though:

f, ax = plt.subplots(1)
ax.scatter(feats[:, 0], feats[:, 1], c=np.array(['C0', 'C1', 'C2'])[target])

But I cannot figure out a way to generate a corresponding legend with this second strategy. All of the examples I've come across iterate over the groups, which seems less than ideal. I know I can build a legend manually, but again that seems overly cumbersome.
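For comparison, the manual legend for the single-call version can be fairly short. This is a sketch that assumes proxy handles built with matplotlib.lines.Line2D (one marker per group) and uses the integer group id as the label text.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from sklearn.datasets import load_iris

iris = load_iris()
feats, target = iris['data'], iris['target']

colors = np.array(['C0', 'C1', 'C2'])
f, ax = plt.subplots(1)
# Single scatter call: each point takes the color of its group
ax.scatter(feats[:, 0], feats[:, 1], c=colors[target])

# Proxy artists: one invisible-line marker per group, colored to match the points
handles = [
    Line2D([0], [0], marker='o', linestyle='', color=colors[i], label=str(i))
    for i in np.unique(target)
]
ax.legend(handles=handles)
```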

dan_g asked Oct 18 '22 12:10


1 Answer

The matplotlib scatter example that addresses this problem also uses a loop, so that is probably the intended usage: https://matplotlib.org/examples/lines_bars_and_markers/scatter_with_legend.html
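If your matplotlib is 3.1 or newer, the collection returned by scatter has a legend_elements() method that builds one legend handle per unique value of the colormapped array, so no loop is needed. A sketch, assuming you pass the integer group ids to c= (a colormap is then applied; an array of color strings like the one in the question is not colormapped, so legend_elements() cannot use it):

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
feats, target = iris['data'], iris['target']

f, ax = plt.subplots(1)
# Integer group ids go to c= so they are mapped through the colormap
sc = ax.scatter(feats[:, 0], feats[:, 1], c=target, cmap='viridis')

# One proxy handle per unique value in the mapped array (here 0, 1, 2, in sorted order);
# the default labels are discarded and replaced with the class names
handles, _ = sc.legend_elements()
ax.legend(handles, list(iris['target_names']), title='species')
```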

If your larger goal is simply to make plotting and labeling categorical data more straightforward, consider Seaborn. This question is similar to "Scatter plots in Pandas/Pyplot: How to plot by category".

One way to accomplish your goal is to use pandas with labeled columns. Once your data is in a pandas DataFrame, you can use Seaborn's pairplot to make this sort of plot. (Seaborn also provides the iris dataset as a labeled DataFrame.)

import seaborn as sns
iris = sns.load_dataset("iris")
sns.pairplot(iris, hue="species")


If you just want the first two features, you can use

sns.pairplot(x_vars=['sepal_length'], y_vars=['sepal_width'], data=iris, hue="species", height=5)
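For a single pair of features, seaborn's scatterplot (seaborn 0.9 or newer) may be simpler than a one-cell pairplot: one call colors by hue and builds the legend itself. In this sketch the DataFrame is assembled from the sklearn data so it runs offline (sns.load_dataset("iris") downloads the data); the column names are chosen to match the seaborn version.

```python
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris

sk = load_iris()
# First two features only, named like the seaborn iris columns
df = pd.DataFrame(sk['data'][:, :2], columns=['sepal_length', 'sepal_width'])
# Class name per sample, used as the hue variable
df['species'] = sk['target_names'][sk['target']]

# One call: points colored by species, legend created automatically
ax = sns.scatterplot(data=df, x='sepal_length', y='sepal_width', hue='species')
```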


If you really want to use the sklearn data, you can pull it into a DataFrame like so:

import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris

iris_sk = load_iris()
feats = iris_sk['data']
target = iris_sk['target']
feat_names = list(iris_sk['feature_names'])
target_names = iris_sk['target_names']

# Keep the features numeric and store each sample's class name in a 'target' column
sk_df = pd.DataFrame(feats, columns=feat_names)
sk_df['target'] = target_names[target]
sns.pairplot(sk_df, vars=feat_names, hue="target")
Bob Baxley answered Oct 21 '22 01:10