 

Changing color and marker of each point using seaborn jointplot

I have this code, slightly modified from here:

import seaborn as sns
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
color = sns.color_palette()[5]
# Newer seaborn: x/y must be keywords, size is now height, stat_func was dropped
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12), color='k', height=7)

g.set_axis_labels('total bill', 'tip', fontsize=16)

and I get a nice-looking plot. However, in my case I need to be able to change the color AND format (marker style) of each individual point.

I've tried using the keywords marker, style, and fmt, but I get the error TypeError: jointplot() got an unexpected keyword argument.

What is the correct way to do this? I'd like to avoid calling sns.JointGrid and plotting the data and marginal distributions manually.

asked Nov 18 '14 at 23:11 by pbreach




2 Answers

Solving this problem is almost no different from doing it in plain matplotlib (plotting a scatter plot with different markers and colors), except that I wanted to keep the marginal distributions:

import numpy as np
import seaborn as sns
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12), color='k', height=7)

# Clear the axes containing the scatter plot (this also removes the regression line)
g.ax_joint.cla()

# Generate some colors and markers (one of each per row)
colors = np.random.random((len(tips), 3))
markers = ['x', 'o', 'v', '^', '<'] * 100

# Plot each individual point separately
for i, row in enumerate(tips.values):
    g.ax_joint.plot(row[0], row[1], color=colors[i], marker=markers[i])

g.set_axis_labels('total bill', 'tip', fontsize=16)

Which gives me this:

[image: resulting joint plot with per-point colors and markers]

The regression line is now gone, but this is all I needed.

answered Oct 03 '22 at 20:10 by pbreach


The accepted answer is too complicated. plt.sca() can be used to do this in a simpler way:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg",
                  xlim=(0, 60), ylim=(0, 12))

# or g.ax_joint.collections[0].set_visible(False), as per mwaskom's comment,
# which hides only the points and keeps the regression line
g.ax_joint.cla()

# set the current axis to be the joint plot's axis
plt.sca(g.ax_joint)

# plt.scatter takes a 'c' keyword for color
# you can also pass an array of floats and use the 'cmap' keyword to
# convert them into a colormap
plt.scatter(tips.total_bill, tips.tip, c=np.random.random((len(tips), 3)))
answered Oct 03 '22 at 18:10 by Max Shron