
Compare 1 independent vs many dependent variables using seaborn pairplot in a horizontal plot

Tags: python, seaborn

The pairplot function from seaborn lets you plot pairwise relationships in a dataset.

According to the documentation (highlight added):

By default, this function will create a grid of Axes such that each variable in data will be shared in the y-axis across a single row and in the x-axis across a single column. The diagonal Axes are treated differently, drawing a plot to show the univariate distribution of the data for the variable in that column.

It is also possible to show a subset of variables or plot different variables on the rows and columns.

I could find only one example of subsetting different variables for rows and columns, here (it's the 6th plot under the Plotting pairwise relationships with PairGrid and pairplot() section). As you can see, it's plotting many independent variables (x_vars) against the same single dependent variable (y_vars) and the results are pretty nice.

I'm trying to do the same, but plotting a single independent variable against many dependent ones.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

ages = np.random.gamma(6,3, size=50)
data = pd.DataFrame({"age": ages,
                     "weight": 80*ages**2/(ages**2+10**2)*np.random.normal(1,0.2,size=ages.shape),
                     "height": 1.80*ages**5/(ages**5+12**5)*np.random.normal(1,0.2,size=ages.shape),
                     "happiness": (1-ages*0.01*np.random.normal(1,0.3,size=ages.shape))})

pp = sns.pairplot(data=data,
                  x_vars=['age'],
                  y_vars=['weight', 'height', 'happiness'])

The problem is that the subplots get arranged vertically, and I couldn't find a way to change it.

I know the tiling would then be less neat, since the y-axis would need a label on every subplot. I also know I could build the plots by hand with something like this:

fig, axes = plt.subplots(ncols=3)
for i, yvar in enumerate(['weight', 'height', 'happiness']):
    axes[i].scatter(data['age'],data[yvar])
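
For comparison, a fuller sketch of that manual approach, with the per-subplot y labels mentioned above. The fixed seed, `figsize`, and `sharex=True` are assumptions added for illustration; they are not part of the original snippet.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Same synthetic data as in the question; the seed is an assumption
# added here so the figure is reproducible.
rng = np.random.default_rng(0)
ages = rng.gamma(6, 3, size=50)
data = pd.DataFrame({
    "age": ages,
    "weight": 80 * ages**2 / (ages**2 + 10**2) * rng.normal(1, 0.2, size=ages.shape),
    "height": 1.80 * ages**5 / (ages**5 + 12**5) * rng.normal(1, 0.2, size=ages.shape),
    "happiness": 1 - ages * 0.01 * rng.normal(1, 0.3, size=ages.shape),
})

y_vars = ["weight", "height", "happiness"]

# One row, one column per dependent variable; x is shared, y is not.
fig, axes = plt.subplots(ncols=len(y_vars), sharex=True, figsize=(12, 4))
for ax, y_var in zip(axes, y_vars):
    ax.scatter(data["age"], data[y_var])
    ax.set_xlabel("age")
    ax.set_ylabel(y_var)  # each subplot needs its own y label
fig.tight_layout()
```

This works for a plain scatter, but anything like colouring by group would have to be added by hand inside the loop.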

Still, I'm learning to use seaborn and I find its interface very convenient, so I wonder if there's a way. Also, this example is pretty simple, but for more complex datasets seaborn handles many more things for you (hue, to start) that would quickly make the raw-matplotlib approach much more complex.

mgab asked Aug 12 '15 13:08


1 Answer

You can achieve what it seems you are looking for by swapping the variable names passed to the x_vars and y_vars parameters. So revisiting the sns.pairplot portion of your code:

pp = sns.pairplot(data=data,
                  y_vars=['age'],
                  x_vars=['weight', 'height', 'happiness'])

Note that all I've done here is swap x_vars for y_vars. The plots should now be displayed horizontally.

Each plot now has its own x-axis, and all of them share a common y-axis determined by the age column.

MarieJ answered Oct 19 '22 17:10