
Drawing lines on scatter with Seaborn

Tags:

draw

seaborn

My objective is to draw a horizontal red line at y = 0 on a plot made with seaborn's sns.lmplot, split by col= or row=.

import numpy as np, seaborn as sns, pandas as pd
import matplotlib.pyplot as plt
np.random.seed(5)

myData = pd.DataFrame({'x': np.arange(1, 101), 'y': np.random.normal(0, 4, 100),
                       'z': ['a', 'b']*50, 'w': np.random.poisson(0.15, 100)})


sns.lmplot(x="x", y="y", col="z", row="w", data=myData, fit_reg=False)
plt.plot(np.linspace(-20, 120, 1000), [0]*1000, 'r-')  # draws only on the current (last) Axes

Only the last plot in the grid gets the red line.


Thanks for your help,

EDIT: reworded the question to cover the case where an array of plots is generated with col= and/or row= and the line should be drawn on each plot.
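For reference, one common workaround (a minimal sketch, not taken from the answers below): plt.plot draws only on the current Axes, which after sns.lmplot is the last facet. The FacetGrid that lmplot returns exposes all of its Axes as a NumPy array in g.axes, so you can loop over them and call axhline on each. The name g is illustrative; the data matches the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(5)
myData = pd.DataFrame({
    "x": np.arange(1, 101),
    "y": np.random.normal(0, 4, 100),
    "z": ["a", "b"] * 50,
    "w": np.random.poisson(0.15, 100),
})

# lmplot returns a FacetGrid; g.axes is a 2-D array with one Axes per facet
g = sns.lmplot(x="x", y="y", col="z", row="w", data=myData, fit_reg=False)

# Draw the reference line on every facet instead of only the current Axes
for ax in g.axes.flat:
    ax.axhline(0, color="r", linestyle="-")
```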

asked Mar 23 '16 at 23:03 by Alex Fortin


People also ask

How do you draw a line in Seaborn?

To draw a line plot with Seaborn, use the lineplot() function. Its x parameter takes the values to display on the x-axis, and its y parameter takes the values for the y-axis.

How do you draw a scatter plot using Seaborn?

To make a scatter plot in Python you can use Seaborn's scatterplot() function. For example, to examine the relationship between the variables "Y" and "X", you can run: sns.scatterplot(x="X", y="Y", data=dataframe).
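A small runnable sketch of that call with a hypothetical two-column data frame. Naming the columns with the x= and y= keywords makes it explicit which variable ends up on which axis:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import pandas as pd
import seaborn as sns

# Hypothetical data frame with columns "X" and "Y"
dataframe = pd.DataFrame({"X": [1, 2, 3, 4], "Y": [10, 8, 12, 9]})

# X goes on the horizontal axis, Y on the vertical axis
ax = sns.scatterplot(x="X", y="Y", data=dataframe)
```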

How do you graph multiple lines in Seaborn?

You probably need to reorganize your data frame so that there is one column for the x data, one for the y data, and one that holds the label for each data point. You can also just use matplotlib.pyplot directly; if you call sns.set_theme(), seaborn's styling is also applied to plain matplotlib plots.

How do you insert a horizontal line in Seaborn?

Use the refline() method of seaborn's FacetGrid (added in seaborn 0.11.2) to add horizontal and/or vertical reference lines to every subplot. Pass y= for the position of a horizontal line and x= for the position of a vertical line.


3 Answers

This code works for the general case that uses col=, row=, and hue=.

import numpy as np, seaborn as sns, pandas as pd
import matplotlib.pyplot as plt
np.random.seed(5)

myData = pd.DataFrame({'x': np.arange(1, 101), 'y': np.random.normal(0, 4, 100),
                       'z': ['a', 'b']*50, 'w': np.random.poisson(0.15, 100),
                       'hueMe': ['q', 'w', 'e', 'r', 't']*20})

# height= replaced the old size= argument in newer seaborn versions
myPlot = sns.FacetGrid(col="z", row='w', hue='hueMe', data=myData, height=5)
myPlot = myPlot.map(plt.scatter, "x", "y").set(xlim=(-20, 120), ylim=(-15, 15))
myPlot = myPlot.map_dataframe(plt.plot, [-20, 120], [0, 0], 'r-').add_legend().set_axis_labels("x", "y")
plt.show()


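I'm not sure why the horizontal line comes out in the last color used on each individual plot, but I give up on this for now :) (A likely explanation: with hue=, map_dataframe calls plt.plot once per hue level and passes that level's color= keyword, which overrides the 'r' in the format string, so the line drawn last shows the last hue color. Passing color='r' as a keyword argument to map_dataframe should keep it red.)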

answered Oct 16 '22 at 03:10 by Alex Fortin


Since I came across this looking for an answer, here is a more general answer that I eventually discovered:

map_dataframe also accepts a user-defined function and passes the data frame to it (as the data keyword argument). This is quite powerful because you can plot anything onto the FacetGrid. In the OP's case:

import matplotlib.pyplot as plt, seaborn as sns

# myData as defined in the previous answer (including the 'hueMe' column)
def plot_hline(y, **kwargs):
    data = kwargs.pop("data")  # get the data frame from the kwargs (unused here)
    plt.axhline(y=y, c='red', linestyle='dashed', zorder=-1)  # zorder places the line underneath the other points

myPlot = sns.FacetGrid(col="z", row='w', hue='hueMe', data=myData, height=5)
myPlot.map(plt.scatter, "x", "y").set(xlim=(-20, 120), ylim=(-15, 15))
myPlot.map_dataframe(plot_hline, y=0)
plt.show()

My problem was slightly more complex because I wanted a different horizontal line on each facet.

To replicate my case, assume the 'z' variable has two samples (a and b), each with one observed value 'obs' (added to myData below). 'hueMe' represents the modeled values for each sample.

myData = pd.DataFrame({'x' :  np.arange(1, 101), 
                       'y': np.random.normal(0, 4, 100),
                       'z' : ['a','b']*50,
                       'w':np.random.poisson(0.15,100),
                       'hueMe':['q','w','e','r','t']*20,
                       'obs':[3,2]*50})

When you pass the data frame to plot_hline, you need to drop the duplicate values of 'obs' for each 'z' sample, because axhline takes only a single y value (remember that in our case each sample has one observed value 'obs' but multiple modeled 'hueMe' values). Further, y must be a scalar rather than a Series, so you need to index into the data frame to extract the value itself.

def plot_hline(y, z, **kwargs):
    # data is only the subset of the original frame for the facet (row, col
    # and hue level) being drawn; the function runs once per subset, so it
    # is a for loop in disguise
    data = kwargs.pop("data")
    data = data.drop_duplicates([z])  # drop the duplicate rows
    yval = data[y].iloc[0]  # extract the scalar value for your hline
    plt.axhline(y=yval, c='red', linestyle='dashed', zorder=-1)


myPlot = sns.FacetGrid(col="z", row='w', hue='hueMe', data=myData, height=5)
myPlot.map(plt.scatter, "x", "y").set(xlim=(-20, 120), ylim=(-15, 15))
myPlot.map_dataframe(plot_hline, y='obs', z='z')
plt.show()


Seaborn now calls your function on each facet of the FacetGrid. Note that if you use a plotting function other than axhline, you might not need to extract a scalar from the Series.

Hope this helps someone!

answered Oct 16 '22 at 03:10 by JRS


Seaborn is really just an interface for matplotlib, so you can use all of your standard matplotlib functions as well. Importing pyplot and plotting a red horizontal line after your seaborn plot works for me.

import numpy as np, seaborn as sns, pandas as pd
import matplotlib.pyplot as plt
np.random.seed(5)

myData = pd.DataFrame({'x': np.arange(1, 101), 'y': np.random.normal(0, 4, 100)})

# fit_reg=False draws no regression line, so line_kws would have no effect here
sns.lmplot(x="x", y="y", data=myData, fit_reg=False)
plt.plot(np.linspace(-20, 120, 1000), [0]*1000, 'r')

My image is here - http://i.imgur.com/J7Lvt52.png
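Without col= or row=, the grid returned by lmplot holds a single Axes, exposed as g.ax, so a sketch that targets it explicitly (rather than relying on the current Axes) could look like this; g is an illustrative name:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(5)
myData = pd.DataFrame({"x": np.arange(1, 101), "y": np.random.normal(0, 4, 100)})

g = sns.lmplot(x="x", y="y", data=myData, fit_reg=False)

# g.ax is only defined when the grid has a single Axes (no col=/row=)
g.ax.axhline(0, color="r")
```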

answered Oct 16 '22 at 01:10 by Tim