 

Overlay a line function on a scatter plot - seaborn

My challenge is to overlay a custom line function over a scatter plot I already have. The code is as follows:

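import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# base_beta holds [constant, coefficient] from a fitted regression (results)
base_beta = results.params
X_plot = np.linspace(0, 1, 400)

g = sns.FacetGrid(data, height=6)
g = g.map(plt.scatter, "usable_area", "price", edgecolor="w")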

Here base_beta holds only a constant and one coefficient. I want to overlay the line y = constant + coefficient * x.

I tried to overlay the line with the following, but it did not work:

g = g.map_dataframe(plt.plot, X_plot, X_plot*base_beta[1]+base_beta[0], 'r-')
plt.show()

The current scatter plot looks like this:
[Image: current scatter plot]

Can anyone help me with this?

--ATTEMPT 1

base_beta = results.params
X_plot = np.linspace(0, 1, 400)
Y_plot = base_beta[0] + base_beta[1]*X_plot

g = sns.FacetGrid(data, height=6)
g = g.map(plt.scatter, "usable_area", "price", edgecolor="w")
plt.plot(X_plot, Y_plot, color='r')
plt.show()

This resulted in the same graph, but with no line:
[Image: scatter plot without the line]

asked Oct 06 '17 04:10 by Svarto




2 Answers

You can just call plt.plot to draw a line over the data. With a single facet, the current Axes after g.map is that facet, so plt.plot draws onto it.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Sample data: price is roughly linear in usable_area
data = pd.DataFrame()
data['usable_area'] = 5*np.random.random(200)
data['price'] = 10*data['usable_area'] + 10*np.random.random(200)

# Points on the line y = 5 + 10*x, covering the range of the data
X_plot = np.linspace(0, 7, 100)
Y_plot = 10*X_plot + 5

g = sns.FacetGrid(data, height=6)
g = g.map(plt.scatter, "usable_area", "price", edgecolor="w")
# plt.plot draws on the current Axes, i.e. the grid's only facet
plt.plot(X_plot, Y_plot, color='r')
plt.show()

Produces:
[Image: the scatter plot with the red line overlaid]

answered Nov 10 '22 11:11 by Robbie


  • It is now recommended to use figure-level functions such as seaborn.relplot or seaborn.lmplot instead of using seaborn.FacetGrid directly
  • Tested in python 3.8.12, pandas 1.3.3, matplotlib 3.4.3, seaborn 0.11.2

Sample Data and Imports

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# create a dataframe with sample x and y
np.random.seed(365)
x = 5*np.random.random(200)
df = pd.DataFrame({'x': x, 'y': 10*x+10*np.random.random(200)})

# add custom line to the dataframe
base_beta = [10, 5]
df['y_line'] = base_beta[0] + base_beta[1]*df.x

print(df.head())  # display() is only available in Jupyter/IPython
          x          y     y_line
0  4.707279  50.634968  33.536394
1  3.208014  33.890507  26.040068
2  3.423052  37.853276  27.115262
3  2.942810  29.899257  24.714052
4  2.719436  36.932170  23.597180

Add Custom Line to Scatter Plot

sns.relplot with .map or .map_dataframe

  • Apply an axes-level plotting function (e.g. sns.lineplot) to each facet of the figure-level plot.
  • seaborn: Building structured multi-plot grids
p1 = sns.relplot(kind='scatter', x='x', y='y', data=df, height=3.5, aspect=1.5)
p1.map_dataframe(sns.lineplot, 'x', 'y_line', color='g')


sns.scatterplot with sns.lineplot

  • Draw two axes-level plots on the same Axes
fig, ax = plt.subplots(figsize=(6, 4))
p1 = sns.scatterplot(data=df, x='x', y='y', ax=ax)
p2 = sns.lineplot(data=df, x='x', y='y_line', color='g', ax=ax)

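If you only have the intercept and slope (like base_beta in the question), Matplotlib's Axes.axline (added in Matplotlib 3.3, so available in the 3.4.3 tested above) can draw an infinite line through the point (0, intercept) with the given slope. No X_plot array or y_line column is needed. A sketch using the same sample data:

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(365)
x = 5*np.random.random(200)
df = pd.DataFrame({'x': x, 'y': 10*x + 10*np.random.random(200)})
base_beta = [10, 5]  # [intercept, slope]

fig, ax = plt.subplots(figsize=(6, 4))
sns.scatterplot(data=df, x='x', y='y', ax=ax)
# Infinite straight line through (0, intercept) with the given slope,
# drawn across the whole visible Axes
line = ax.axline((0, base_beta[0]), slope=base_beta[1], color='g')
plt.show()
```

Note that the slope= form of axline only works on linear axes, so it is not suitable if you switch to a log scale.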


Add a Regression Line to a Scatter Plot

  • For a regression line:
    • Use seaborn.lmplot for a figure-level regression plot.
    • Use seaborn.regplot for an axes-level regression plot.

sns.lmplot

p1 = sns.lmplot(data=df, x='x', y='y', line_kws={'color': 'g'}, height=3.5, aspect=1.5)


sns.regplot

p2 = sns.regplot(data=df, x='x', y='y', line_kws={'color': 'g'})

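regplot and lmplot fit the line internally and do not return its coefficients. If you need them, for example to build base_beta for a custom line, one option is np.polyfit with deg=1, which performs an ordinary least-squares fit of a straight line. Note that it returns the coefficients highest power first, i.e. [slope, intercept], the reverse of the [constant, coefficient] order used for base_beta. For a first-order fit, the line drawn from these coefficients should coincide with regplot's line. This sketch reuses the sample data:

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(365)
x = 5*np.random.random(200)
df = pd.DataFrame({'x': x, 'y': 10*x + 10*np.random.random(200)})

# Least-squares straight line; polyfit returns [slope, intercept]
slope, intercept = np.polyfit(df.x, df.y, deg=1)
base_beta = [intercept, slope]  # same order as in the question

ax = sns.regplot(data=df, x='x', y='y', line_kws={'color': 'g'})
# Overlay the line built from the fitted coefficients (dashed red)
X_plot = np.linspace(df.x.min(), df.x.max(), 100)
ax.plot(X_plot, base_beta[0] + base_beta[1]*X_plot, 'r--')
plt.show()
```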

answered Nov 10 '22 11:11 by Trenton McKinney