 

Showing data and model predictions in one plot using Seaborn and Statsmodels

Seaborn is a great package for doing some high-level plotting with pretty outputs. However, I'm struggling a little with using Seaborn to overlay both data and model predictions from an externally-fit model. In this example I am fitting models in Statsmodels that are too complex for Seaborn to do out-of-the-box, but I think the problem is more general (i.e. if I have model predictions and want to visualise both them and data using Seaborn).

Let's start with imports and a dataset:

import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.formula.api as smf
import patsy
import itertools
import matplotlib.pyplot as plt

np.random.seed(12345)

# make a data frame with one continuous and two categorical variables:
df = pd.DataFrame({'x1': np.random.normal(size=100),
                   'x2': np.tile(np.array(['a', 'b']), 50),
                   'x3': np.repeat(np.array(['c', 'd']), 50)})

# create a design matrix using patsy:
X = patsy.dmatrix('x1 * x2 * x3', df)

# some random beta weights:
betas = np.random.normal(size=X.shape[1])

# create the response variable as the noisy linear combination of predictors:
df['y'] = np.inner(X, betas) + np.random.normal(size=100)

We fit a model in statsmodels containing all predictor variables and their interactions:

# fit a model with all interactions
fit = smf.ols('y ~ x1 * x2 * x3', df).fit()
print(fit.summary())

Since in this case we have all combinations of variables specified, and our model predictions are linear, it would suffice for plotting to add a new "predictions" column to the dataframe containing the model predictions. However, that's not very general (imagine our model is nonlinear and so we want our plots to show smooth curves), so instead I make a new dataframe with all combinations of predictors, then generate predictions:

# helper that builds every combination of predictor values (a port of R's expand.grid):
def expand_grid(data_dict):
    """ A port of R's expand.grid function for use with Pandas dataframes.

    from http://pandas.pydata.org/pandas-docs/stable/cookbook.html?highlight=expand%20grid

    """
    rows = itertools.product(*data_dict.values())
    return pd.DataFrame.from_records(rows, columns=list(data_dict.keys()))


# build a new matrix with expand grid:

preds = expand_grid(
                {'x1': np.linspace(df['x1'].min(), df['x1'].max(), 2),
                 'x2': ['a', 'b'],
                 'x3': ['c', 'd']})
preds['yhat'] = fit.predict(preds)

The preds dataframe looks like this:

  x3        x1 x2      yhat
0  c -2.370232  a -1.555902
1  c -2.370232  b -2.307295
2  c  3.248944  a -1.555902
3  c  3.248944  b -2.307295
4  d -2.370232  a -1.609652
5  d -2.370232  b -2.837075
6  d  3.248944  a -1.609652
7  d  3.248944  b -2.837075
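On pandas 0.24 or later, the cookbook helper can also be written with pandas' own pd.MultiIndex.from_product, which enumerates the same Cartesian product; on Python 3.7+ the columns come out in the order the dict lists them. A sketch under those assumptions (the grid values below are illustrative, not the fitted data):

```python
import numpy as np
import pandas as pd


def expand_grid(data_dict):
    """Return a DataFrame with one row per combination of the value lists."""
    # MultiIndex.from_product enumerates every combination of the input lists;
    # to_frame(index=False) turns the index levels into ordinary columns.
    index = pd.MultiIndex.from_product(list(data_dict.values()),
                                       names=list(data_dict.keys()))
    return index.to_frame(index=False)


# Illustrative grid: 2 values of x1 x 2 levels of x2 x 2 levels of x3 -> 8 rows.
grid = expand_grid({'x1': np.linspace(-2.0, 3.0, 2),
                    'x2': ['a', 'b'],
                    'x3': ['c', 'd']})
print(grid)
```

The result can be passed to fit.predict exactly like the itertools version.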

Since Seaborn plot commands (unlike ggplot2 commands in R) appear to accept one and only one dataframe, we need to merge our predictions into the raw data:

# append to df (DataFrame.append was removed in pandas 2.0; pd.concat replaces it):
merged = pd.concat([df, preds], ignore_index=True)
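When stacking raw data and predictions, pd.concat aligns columns by name, so data rows get NaN in yhat and prediction rows get NaN in y. Passing keys= additionally records where each row came from, which makes it easy to filter or colour by source later. A minimal sketch with small made-up stand-ins for df and preds (the column name source is an arbitrary choice):

```python
import pandas as pd

# Small stand-ins for the question's raw data (df) and prediction grid (preds).
df = pd.DataFrame({'x1': [0.1, -0.4, 1.2],
                   'x2': ['a', 'b', 'a'],
                   'y': [1.0, 2.0, 0.5]})
preds = pd.DataFrame({'x1': [-1.0, 2.0],
                      'x2': ['a', 'a'],
                      'yhat': [0.3, 0.9]})

# keys= adds an outer index level naming the origin of each row;
# reset_index(level='source') moves that level into a column.
merged = (pd.concat([df, preds], keys=['data', 'model'], names=['source'])
            .reset_index(level='source')
            .reset_index(drop=True))
print(merged)
```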

We can now plot the model predictions along with the data, with our continuous variable x1 as the x-axis:

# plot using seaborn:
sns.set_style('white')
sns.set_context('talk')
g = sns.FacetGrid(merged, hue='x2', col='x3', height=5)  # `size` was renamed `height` in seaborn 0.9
# use the `map` method to add stuff to the facetgrid axes:
g.map(plt.plot, "x1", "yhat")
g.map(plt.scatter, "x1", "y")
g.add_legend()
g.fig.subplots_adjust(wspace=0.3)
sns.despine(offset=10);
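The merge is a convenience for seaborn rather than a necessity. With plain matplotlib you can draw each facet yourself, taking the scatter from the raw data and the lines from the prediction grid, without combining the two frames. A minimal sketch with made-up stand-in data (the slope of 0.5 and the sample sizes are arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Hypothetical stand-ins: raw observations in df, model predictions in preds.
df = pd.DataFrame({'x1': rng.normal(size=40),
                   'x2': np.tile(['a', 'b'], 20),
                   'x3': np.repeat(['c', 'd'], 20)})
df['y'] = 0.5 * df['x1'] + rng.normal(size=40)
preds = pd.DataFrame({'x1': np.tile([-2.0, 2.0], 4),
                      'x2': np.repeat(['a', 'b'], 2).tolist() * 2,
                      'x3': np.repeat(['c', 'd'], 4)})
preds['yhat'] = 0.5 * preds['x1']

x3_levels = ['c', 'd']
colors = {'a': 'C0', 'b': 'C1'}  # one colour per x2 level, like hue='x2'
fig, axes = plt.subplots(1, len(x3_levels), sharey=True, figsize=(8, 4))
for ax, level in zip(axes, x3_levels):  # one panel per x3 level, like col='x3'
    for x2_level, color in colors.items():
        d = df[(df['x3'] == level) & (df['x2'] == x2_level)]
        p = preds[(preds['x3'] == level) & (preds['x2'] == x2_level)].sort_values('x1')
        ax.scatter(d['x1'], d['y'], color=color, label=x2_level)  # data
        ax.plot(p['x1'], p['yhat'], color=color)                  # model
    ax.set_title(f'x3 = {level}')
    ax.set_xlabel('x1')
axes[0].set_ylabel('y')
axes[-1].legend(title='x2')
```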


So far so good. Now imagine that we didn't measure the continuous variable x1, and we only know about the other two (categorical) variables (i.e., we have a 2x2 factorial design). How can we plot the model predictions against data in this case?

fit = smf.ols('y ~ x2 * x3', df).fit()
print(fit.summary())

preds = expand_grid(
                {'x2': ['a', 'b'],
                 'x3': ['c', 'd']})
preds['yhat'] = fit.predict(preds)
print(preds)

# append to df (DataFrame.append was removed in pandas 2.0; pd.concat replaces it):
merged = pd.concat([df, preds], ignore_index=True)
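Because y ~ x2 * x3 is saturated for a 2x2 design (four parameters, four cells), its OLS predictions are simply the four cell means, so the prediction frame has exactly four distinct values to plot. A quick check, using a hand-built treatment-coded design and numpy's least squares as a stand-in for statsmodels, on random placeholder data:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
# Placeholder 2x2 factorial data standing in for the question's df.
df = pd.DataFrame({'x2': np.tile(['a', 'b'], 50),
                   'x3': np.repeat(['c', 'd'], 50)})
df['y'] = rng.normal(size=100)

# Treatment-coded design for y ~ x2 * x3: intercept, x2[b], x3[d], x2[b]:x3[d].
b = (df['x2'] == 'b').to_numpy(dtype=float)
d = (df['x3'] == 'd').to_numpy(dtype=float)
X = np.column_stack([np.ones(len(df)), b, d, b * d])
beta, *_ = np.linalg.lstsq(X, df['y'].to_numpy(), rcond=None)
df['yhat'] = X @ beta

# Compare the observed mean of each cell with the fitted value in that cell.
cell_means = df.groupby(['x2', 'x3'])['y'].mean()
fitted_per_cell = df.groupby(['x2', 'x3'])['yhat'].first()
print(pd.DataFrame({'cell_mean': cell_means, 'fitted': fitted_per_cell}))
```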

Well, we can plot the model predictions using sns.pointplot or similar, like so:

# plot using seaborn:
g = sns.FacetGrid(merged, hue='x3', height=4)
g.map(sns.pointplot, 'x2', 'yhat', order=['a', 'b'])
g.add_legend();
sns.despine(offset=10);


Or the data using sns.catplot (named sns.factorplot before seaborn 0.9) like so:

g = sns.catplot(data=merged, x='x2', y='y', hue='x3', kind='point')
sns.despine(offset=10);
g.savefig('tmp3.png')


But I do not see how to produce a plot similar to the first one (i.e. lines for model predictions using plt.plot, a scatter of points for data using plt.scatter). The reason is that the x2 variable I'm trying to use as the x-axis is a string (object) column, so the pyplot commands don't know where to place its values on the axis.
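One way to get the first style of plot with a categorical x-axis is roughly what seaborn's categorical plots do internally: map each level of x2 to an integer position, draw at those positions, and relabel the ticks. A sketch with made-up data; the jitter width of 0.1 and the use of cell means as "predictions" are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
# Hypothetical stand-ins for the 2x2 data (df) and model predictions (preds).
df = pd.DataFrame({'x2': np.tile(['a', 'b'], 50),
                   'x3': np.repeat(['c', 'd'], 50)})
df['y'] = rng.normal(size=100)
preds = (df.groupby(['x2', 'x3'], as_index=False)['y'].mean()
           .rename(columns={'y': 'yhat'}))

levels = ['a', 'b']                             # category order on the x-axis
pos = {lev: i for i, lev in enumerate(levels)}  # category -> integer position

fig, ax = plt.subplots()
for x3_level, color in zip(['c', 'd'], ['C0', 'C1']):
    d = df[df['x3'] == x3_level]
    p = preds[preds['x3'] == x3_level]
    # Data: numeric positions plus a small horizontal jitter so points don't overlap.
    x_data = d['x2'].map(pos) + rng.uniform(-0.1, 0.1, size=len(d))
    ax.scatter(x_data, d['y'], color=color, alpha=0.5, label=f'x3 = {x3_level}')
    # Model: one line through the predicted value at each category position.
    ax.plot(p['x2'].map(pos), p['yhat'], color=color, marker='D')
ax.set_xticks(range(len(levels)))
ax.set_xticklabels(levels)
ax.set_xlabel('x2')
ax.set_ylabel('y')
ax.legend()
```

Matplotlib 2.1 and later can also accept string x values directly and treat them as categories, but explicit integer positions make it straightforward to add jitter.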

asked Jan 30 '15 at 15:01 by tsawallis




1 Answer

As I mentioned in my comments, there are two ways I would approach this.

The first is to define a function that does the fit and then plots and pass it to FacetGrid.map:

import pandas as pd
import seaborn as sns
tips = sns.load_dataset("tips")

def plot_good_tip(day, total_bill, **kws):
    # FacetGrid.map passes this facet's columns as Series (plus color/label
    # keywords, ignored here); plot a 20% tip on each day's mean bill.
    expected_tip = (total_bill.groupby(day, observed=True)
                              .mean()
                              .apply(lambda x: x * .2)
                              .reset_index(name="tip"))
    sns.pointplot(data=expected_tip, x="day", y="tip",
                  linestyles="--", markers="D")

g = sns.FacetGrid(tips, col="sex", height=5)
g.map(sns.pointplot, "day", "tip", order=["Thur", "Fri", "Sat", "Sun"])
g.map(plot_good_tip, "day", "total_bill")
g.set_axis_labels("day", "tip")


The second is to compute the predicted values and then merge them into your DataFrame with an additional variable that identifies what is data and what is model:

tip_predict = (tips.groupby(["day", "sex"], observed=True)
                   .total_bill
                   .mean()
                   .apply(lambda x: x * .2)
                   .reset_index(name="tip"))
# the dict keys become an index level named "kind" ("data" or "model")
tip_all = pd.concat(dict(data=tips[["day", "sex", "tip"]], model=tip_predict),
                    names=["kind"]).reset_index()

# sns.factorplot was renamed sns.catplot in seaborn 0.9
sns.catplot(data=tip_all, x="day", y="tip", hue="kind", col="sex",
            kind="point", linestyles=["-", "--"], markers=["o", "D"])
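The pd.concat(dict(...), names=[...]) idiom above may be unfamiliar: the dict keys become an outer index level, and reset_index turns that level into the kind column used as hue in the plotting call. A tiny illustration with hypothetical frames standing in for the tips subset and tip_predict:

```python
import pandas as pd

# Tiny hypothetical frames standing in for the observed tips and the model rows.
data = pd.DataFrame({'day': ['Thur', 'Fri'], 'sex': ['Male', 'Female'], 'tip': [2.0, 3.5]})
model = pd.DataFrame({'day': ['Thur'], 'sex': ['Male'], 'tip': [3.1]})

# Dict keys become the outer index level, named "kind" via names=;
# reset_index() turns both index levels into columns, and the unnamed
# inner level (the original row labels) becomes "level_1".
both = pd.concat(dict(data=data, model=model), names=['kind']).reset_index()
print(both)
```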


answered Oct 22 '22 at 17:10 by mwaskom