Seaborn Jointplot add colors for each class

I want to plot the correlation between two variables using seaborn's jointplot. I have tried many different things, but I have not been able to color the points according to their class.

Here is my code:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

X = np.array([5.2945 , 3.6013 , 3.9675 , 5.1602 , 4.1903 , 4.4995 , 4.5234 ,
              4.6618 , 0.76131, 0.42036, 0.71092, 0.60899, 0.66451, 0.55388,
              0.63863, 0.62504, 0.     , 0.     , 0.49364, 0.44828, 0.43066,
              0.57368, 0.     , 0.     , 0.64824, 0.65166, 0.64968, 0.     ,
              0.     , 0.52522, 0.58259, 1.1309 , 0.     , 0.     , 1.0514 ,
              0.7519 , 0.78745, 0.94873, 1.0169 , 0.     , 0.     , 1.0416 ,
              0.     , 0.     , 0.93648, 0.92801, 0.     , 0.     , 0.89594,
              0.     , 0.80455, 1.0103 ])

y = np.array([ 93, 115, 107, 115, 110, 107, 102, 113,  95, 101, 116,  74, 102,
               102,  78,  85, 108, 110, 109,  80,  91,  88,  99, 110, 108,  96,
               105,  93, 107,  98,  88,  75, 106,  92,  82,  84,  84,  92, 115,
               107,  97, 115,  85, 133, 100,  65,  96, 105, 112, 107, 107, 105])

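# jointplot returns a JointGrid, not an Axes. Recent seaborn versions (0.12+)
# require x and y to be passed as keyword arguments.
g = sns.jointplot(x=X, y=y, kind='reg')
g.set_axis_labels(xlabel='Brain scores', ylabel='Cognitive scores')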
plt.tight_layout()
plt.show()


Now I want to color each point according to a class variable, classes.

asked Jul 06 '18 at 12:07 by seralouk




2 Answers

The obvious solution is to let regplot draw only the regression line, without the points, and then add the points with an ordinary scatter plot, whose c argument sets a color for each point.

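# scatter=False is passed on to regplot, so only the regression line is drawn;
# the points are then added with Axes.scatter, colored by the classes array.
g = sns.jointplot(x=X, y=y, kind='reg', scatter=False)
g.ax_joint.scatter(X, y, c=classes)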


answered Oct 07 '22 at 20:10 by ImportanceOfBeingErnest


I managed to find a solution that does exactly what I need. Thanks to @ImportanceOfBeingErnest, who gave me the idea of letting regplot draw only the regression line.

Solution:

import pandas as pd

classes = np.array([1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
                    2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 
                    2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 
                    3., 3., 3., 3., 3., 3., 3.])

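# Put the points and their class labels in one DataFrame.
df = pd.DataFrame({'X': X, 'y': y, 'cls': classes})

# One explicit color per class, so the points and the marginal KDEs match.
palette = sns.color_palette(n_colors=df['cls'].nunique())

# scatter=False: regplot draws only the regression line and its CI band.
g = sns.jointplot(x=X, y=y, kind='reg', scatter=False)
for color, (_, subdata) in zip(palette, df.groupby('cls')):
    sns.kdeplot(x=subdata['X'], ax=g.ax_marg_x, color=color)
    # y= draws the KDE along the vertical axis (vertical=True is deprecated).
    sns.kdeplot(y=subdata['y'], ax=g.ax_marg_y, color=color)
    g.ax_joint.plot(subdata['X'], subdata['y'], 'o', ms=8, color=color)
plt.tight_layout()
plt.show()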


answered Oct 07 '22 at 20:10 by seralouk