Conditional color with matplotlib scatter

I have the following Pandas Dataframe, where column a represents a dummy variable:

[Image: the pandas DataFrame]

What I would like to do is color my markers with cmap='jet' according to the value of column b, except when the value in column a is equal to 1; in that case I want the marker to be grey.

Any idea how I can do this?

asked Feb 25 '17 at 08:02 by MattnDo



People also ask

How do I color the data points in a Matplotlib scatter plot?

In Matplotlib's scatter() function, you can color the data points by a variable using the c argument. c can be a scalar or a sequence of n numbers, which are mapped to colors using cmap and norm, or a 2-D array whose rows are RGB or RGBA values.
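A minimal sketch of the two forms of c described above, using made-up data: a 1-D array of numbers that is mapped through cmap, and an (n, 4) RGBA array that is used as-is without any colormap (call plt.show() to display the figure).

```python
import matplotlib.pyplot as plt
import numpy as np

# made-up data for illustration
x = np.arange(5)
y = x ** 2
values = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

fig, ax = plt.subplots()

# c as a sequence of numbers: mapped to colors through cmap and the default norm
pts = ax.scatter(x, y, c=values, cmap="jet")
fig.colorbar(pts, ax=ax)

# c as an (n, 4) RGBA array: every row is used directly as a color (here pure red)
rgba = np.tile([1.0, 0.0, 0.0, 1.0], (5, 1))
fixed = ax.scatter(x, y + 1, c=rgba)
```

Only the first collection carries data for color mapping, which is why only it can drive a colorbar.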

How do I add x- and y-axis labels to a Matplotlib scatter plot?

After drawing the scatter plot, set the labels with ax.set_xlabel() and ax.set_ylabel() (or plt.xlabel() and plt.ylabel() with the pyplot interface). The labels are independent of how the points are colored with the c argument.
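A short sketch with illustrative data, using the object-oriented API; the axis label strings are arbitrary examples.

```python
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 2 * np.pi, 30)

fig, ax = plt.subplots()
ax.scatter(t, np.sin(t), c=np.cos(t), cmap="jet")

# plt.xlabel / plt.ylabel / plt.title do the same on the current axes
ax.set_xlabel("t (radians)")
ax.set_ylabel("sin(t)")
ax.set_title("Scatter colored by cos(t)")
```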

How do I create a scatter plot in Python?

Use Matplotlib's scatter() function, calling it once per group of points if needed, and pass the color argument c to set each group's color manually; then call show() to display the figure. To draw every point in a single color, set c to one color format string, for example 'red'.
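A sketch of the multi-call pattern just described, with made-up data: each scatter() call adds one layer, and c='red' or c='blue' fixes that layer's color (call plt.show() to display it).

```python
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 2 * np.pi, 30)

fig, ax = plt.subplots()
# each call adds one collection drawn in a single fixed color
first = ax.scatter(t, np.sin(t), c="red", label="sin")
second = ax.scatter(t, np.cos(t), c="blue", label="cos")
ax.legend()
```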

Why color the points of a scatter plot by a group (categorical) variable?

Datasets often contain several quantitative and categorical variables, and you may be interested in the relationship between two quantitative variables with respect to a third, categorical one. Coloring the points by that group variable makes this relationship much easier to see.
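A minimal sketch, assuming a hypothetical categorical array named group: loop over the distinct groups, draw each with its own scatter() call, and let Matplotlib's default color cycle assign the colors, with a legend to identify them.

```python
import matplotlib.pyplot as plt
import numpy as np

# made-up quantitative data and a hypothetical categorical variable
x = np.linspace(-2, 2, 20)
y = 2 * x + np.sin(5 * x)
group = np.where(x > 0, "positive", "negative")

fig, ax = plt.subplots()
for name in np.unique(group):  # np.unique returns the groups sorted
    mask = group == name
    # no c given: each call takes the next color from the default cycle
    ax.scatter(x[mask], y[mask], label=name)
ax.legend(title="group")
```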


1 Answer

Build a boolean mask that marks the values equal to 1, then plot the two groups separately:

import matplotlib.pyplot as plt
import numpy as np

# test data
t = np.linspace(0, 2 * np.pi, 30)
x = np.sin(t)
x[3] = 1
y = np.cos(t)

# boolean mask of the 'bad' values (where x equals 1)
indices = x == 1
# normalize y to [0, 1] over all points and map it through the jet colormap
cmap = plt.get_cmap('jet')
colors = cmap((y - y.min()) / np.ptp(y))

# normal values: c already holds RGBA colors, so no cmap argument is needed
plt.scatter(t[~indices], x[~indices], c=colors[~indices])
# bad values, drawn in grey
plt.scatter(t[indices], x[indices], c='grey')
plt.show()
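Another option, sketched here as an alternative rather than as the answer's method: replace the color value of the masked points with NaN and let a single scatter() call draw them in the colormap's "bad" color. This relies on the plotnonfinite parameter of scatter() and on Colormap.with_extremes(), so it assumes Matplotlib 3.4 or later.

```python
import matplotlib.pyplot as plt
import numpy as np

# same test data as in the answer
t = np.linspace(0, 2 * np.pi, 30)
x = np.sin(t)
x[3] = 1
y = np.cos(t)

# NaN marks the points that should not be colormapped
c = np.where(x == 1, np.nan, y)
# copy of jet whose "bad" (NaN/masked) color is grey
cmap = plt.get_cmap("jet").with_extremes(bad="grey")

fig, ax = plt.subplots()
# plotnonfinite=True keeps the NaN points and draws them in the bad color
pts = ax.scatter(t, x, c=c, cmap=cmap, plotnonfinite=True)
fig.colorbar(pts, ax=ax)
```

Because the NaN values are masked, the automatic color limits come from the remaining points only, which differs slightly from the answer's normalization over all of y.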


The arrays t, x and y stand in for pandas Series: the mask x == 1 corresponds to the test on column a, and y corresponds to column b.
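A sketch of the same approach applied directly to a DataFrame. The question does not say which columns give the marker positions, so the columns x and y below are hypothetical, as is all of the data. Passing vmin and vmax computed over the whole b column keeps the color scale consistent with the answer, where the colors are normalized over all values, including the grey ones (call plt.show() to display the figure).

```python
import matplotlib.pyplot as plt
import pandas as pd

# hypothetical data: 'a' is the dummy column, 'b' drives the color,
# 'x' and 'y' are assumed position columns
df = pd.DataFrame({
    "x": [1, 2, 3, 4, 5, 6],
    "y": [2, 4, 1, 5, 3, 6],
    "a": [0, 1, 0, 0, 1, 0],
    "b": [0.1, 0.5, 0.3, 0.9, 0.2, 0.7],
})

is_dummy = df["a"] == 1
# color limits taken from every row, including the grey ones
vmin, vmax = df["b"].min(), df["b"].max()

fig, ax = plt.subplots()
# rows with a != 1: colored by b through the jet colormap
colored = ax.scatter(df.loc[~is_dummy, "x"], df.loc[~is_dummy, "y"],
                     c=df.loc[~is_dummy, "b"], cmap="jet", vmin=vmin, vmax=vmax)
# rows with a == 1: plain grey
grey_pts = ax.scatter(df.loc[is_dummy, "x"], df.loc[is_dummy, "y"], c="grey")
fig.colorbar(colored, ax=ax, label="b")
```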

answered Oct 07 '22 at 22:10 by Serenity
