matplotlib scatter plot with color label and legend specified by c option [duplicate]


I'd like to make a scatter plot where the point colors are specified by the "c" option and the legend shows what each color means.

My data looks like this:

    scatter_x = [1, 2, 3, 4, 5]
    scatter_y = [5, 4, 3, 2, 1]
    group = [1, 3, 2, 1, 3]  # each (x, y) belongs to group 1, 2, or 3

I tried this:

    plt.scatter(scatter_x, scatter_y, c=group, label=group)
    plt.legend()

Unfortunately, the legend did not come out as expected. How can I show it properly? I expected three rows (one per group), each showing a color and the group it corresponds to.
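A likely reason for the result: a single scatter call creates a single artist, so passing label=group gives that one artist one label and the legend can show at most one entry. If you want to keep the single c=group call, the following is a minimal sketch. It assumes Matplotlib 3.1 or newer, where PathCollection.legend_elements() is available. The viridis colormap and the headless Agg backend are illustrative choices, not part of the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display; drop it to show a window
import matplotlib.pyplot as plt

scatter_x = [1, 2, 3, 4, 5]
scatter_y = [5, 4, 3, 2, 1]
group = [1, 3, 2, 1, 3]

fig, ax = plt.subplots()
# One scatter call: the numeric group values are mapped to colors through the colormap
sc = ax.scatter(scatter_x, scatter_y, c=group, cmap="viridis", s=100)

# legend_elements() builds one proxy handle (and a formatted label) per unique value of c
handles, labels = sc.legend_elements(prop="colors")
ax.legend(handles, labels, title="group")
# plt.show()  # uncomment in an interactive session
```

With three distinct group values this should yield three legend rows. Matplotlib formats the labels itself, so their exact text may differ from plain "1", "2", "3".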


asked Oct 29 '17 at 23:10 by Light Yagmi



People also ask

What is c in pyplot.scatter?

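c: color, sequence, or sequence of colors, optional. The marker color. Possible values: a single color format string; a sequence of colors with one entry per point; or a sequence of numbers (one per point) to be mapped to colors using cmap and norm. Older Matplotlib documentation listed the default as 'b'; current versions default to None, which takes the next color from the property cycle.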
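The three forms of c can be compared side by side with a short sketch. It assumes the question's five points. The specific colors and the viridis colormap are arbitrary choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [5, 4, 3, 2, 1]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))
# 1) A single color format string: every marker gets the same color
sc1 = ax1.scatter(x, y, c="red")
# 2) A sequence of colors: one color per point, no colormap involved
sc2 = ax2.scatter(x, y, c=["red", "green", "blue", "green", "red"])
# 3) A sequence of numbers: values are mapped to colors through cmap (and norm)
sc3 = ax3.scatter(x, y, c=[1, 3, 2, 1, 3], cmap="viridis")
fig.colorbar(sc3, ax=ax3)  # a colorbar only makes sense for the numeric case
```

Only the numeric form stores the values on the artist (sc3.get_array()). That stored data is what colormaps, colorbars, and legend_elements() work from.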

Can you use Matplotlib to create a colored scatterplot?

Yes. Matplotlib's scatter has a parameter c that accepts an array-like of values or a list of colors. To color points by category, a common pattern is to define a dictionary that maps each category to a plotting color.
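One way to combine such a dictionary with a legend is shown here. It assumes the question's data and a cdict mapping whose colors are chosen arbitrarily. All points are drawn in a single scatter call, which produces a single artist, so the legend entries are built by hand as proxy Line2D handles.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

scatter_x = [1, 2, 3, 4, 5]
scatter_y = [5, 4, 3, 2, 1]
group = [1, 3, 2, 1, 3]
cdict = {1: "red", 2: "blue", 3: "green"}  # category -> plotting color (illustrative)

fig, ax = plt.subplots()
# Map every point's category to its color and draw all points in one call
ax.scatter(scatter_x, scatter_y, c=[cdict[g] for g in group], s=100)

# One proxy handle per category: an invisible line with a round marker in that color
handles = [
    Line2D([], [], marker="o", linestyle="", color=color, label=str(g))
    for g, color in sorted(cdict.items())
]
ax.legend(handles=handles, title="group")
```

The proxies only control what the legend draws. They are not added to the axes, so the plot itself shows just the five points.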

How do you plot a legend for a scatter plot in Python?

To create a scatter plot with a legend, one option is a loop that makes one scatter call per item that should appear in the legend, setting each call's label accordingly. The transparency of the markers can be adjusted by giving alpha a value between 0 and 1.
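A minimal sketch of that loop follows. It assumes the question's data. The value alpha=0.5 and the "group N" label text are illustrative choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

scatter_x = np.array([1, 2, 3, 4, 5])
scatter_y = np.array([5, 4, 3, 2, 1])
group = np.array([1, 3, 2, 1, 3])

fig, ax = plt.subplots()
for g in np.unique(group):
    mask = group == g  # boolean mask selecting the points of group g
    # One call (hence one legend entry) per group; alpha=0.5 makes the markers half transparent
    ax.scatter(scatter_x[mask], scatter_y[mask], label=f"group {g}", alpha=0.5, s=100)
ax.legend()
```

A boolean mask is used here instead of np.where. Both select the same points, but the mask indexes the arrays directly.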


2 Answers

Call scatter once for each group, giving each call its own label, so that every group becomes its own legend entry:

    import numpy as np
    from matplotlib import pyplot as plt

    scatter_x = np.array([1, 2, 3, 4, 5])
    scatter_y = np.array([5, 4, 3, 2, 1])
    group = np.array([1, 3, 2, 1, 3])
    cdict = {1: 'red', 2: 'blue', 3: 'green'}  # group -> color

    fig, ax = plt.subplots()
    for g in np.unique(group):
        ix = np.where(group == g)  # indices of the points in group g
        # one scatter call per group, so each group gets its own legend entry
        ax.scatter(scatter_x[ix], scatter_y[ix], c=cdict[g], label=g, s=100)
    ax.legend()
    plt.show()


answered Oct 07 '22 at 01:10 by p-robot



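The same loop also works without a color dictionary. In that case Matplotlib gives each group the next color from its default color cycle: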

    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots()
    scatter_x = np.array([1, 2, 3, 4, 5])
    scatter_y = np.array([5, 4, 3, 2, 1])
    group = np.array([1, 3, 2, 1, 3])
    for g in np.unique(group):
        i = np.where(group == g)  # indices of the points in group g
        # no color given: each call takes the next color in the default cycle
        ax.scatter(scatter_x[i], scatter_y[i], label=g)
    ax.legend()
    plt.show()
answered Oct 07 '22 at 00:10 by HISI
