
Can pandas groupby transform a DataFrame into a Series?

I would like to use pandas and statsmodels to fit a linear model on subsets of a dataframe and return the predicted values. However, I am having trouble figuring out the right pandas idiom to use. Here is what I am trying to do:

import pandas as pd
import statsmodels.formula.api as sm
import seaborn as sns

tips = sns.load_dataset("tips")
def fit_predict(df):
    m = sm.ols("tip ~ total_bill", df).fit()
    return pd.Series(m.predict(df), index=df.index)
tips["predicted_tip"] = tips.groupby("day").transform(fit_predict)

This raises the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-139-b3d2575e2def> in <module>()
----> 1 tips["predicted_tip"] = tips.groupby("day").transform(fit_predict)

/Users/mwaskom/anaconda/lib/python2.7/site-packages/pandas/core/groupby.pyc in transform(self, func, *args, **kwargs)
   3033                     return self._transform_general(func, *args, **kwargs)
   3034         except:
-> 3035             return self._transform_general(func, *args, **kwargs)
   3036 
   3037         # a reduction transform

/Users/mwaskom/anaconda/lib/python2.7/site-packages/pandas/core/groupby.pyc in _transform_general(self, func, *args, **kwargs)
   2988                     group.T.values[:] = res
   2989                 else:
-> 2990                     group.values[:] = res
   2991 
   2992                 applied.append(group)

ValueError: could not broadcast input array from shape (62) into shape (62,6)

The error makes sense in that I think .transform wants to map a DataFrame to a DataFrame. But is there a way to do a groupby operation on a DataFrame, pass each chunk into a function that reduces it to a Series (with the same index), and then combine the resulting Series into something that can be inserted into the original dataframe?
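The operation described here (split the frame by group, call a function that returns a Series indexed like its chunk, then recombine) can also be written out by hand with `pd.concat`. Because the concatenated Series keeps the original row labels, it can be assigned straight back as a column. This is a minimal sketch on a small toy frame, assuming a least-squares line from `numpy.polyfit` as a stand-in for the statsmodels OLS fit so that it runs without statsmodels; `fit_predict_np` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd

def fit_predict_np(df):
    # Hypothetical stand-in for the statsmodels fit: ordinary least-squares
    # line tip = slope * total_bill + intercept via numpy.polyfit (degree 1).
    slope, intercept = np.polyfit(df["total_bill"], df["tip"], 1)
    # The arithmetic below yields a Series indexed like the chunk.
    return slope * df["total_bill"] + intercept

tips = pd.DataFrame({"day": list("MMMFFF"), "tip": range(6),
                     "total_bill": [10, 40, 20, 80, 50, 40]})

# Split, apply, combine by hand: one Series per group, then concatenate.
pieces = [fit_predict_np(group) for _, group in tips.groupby("day")]
tips["predicted_tip"] = pd.concat(pieces)  # aligns on the original index
print(tips)
```

The groups come back in sorted key order (`F` before `M`), but that does not matter: column assignment aligns on the index, not on position.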

asked Oct 31 '22 18:10 by mwaskom



1 Answer

The top part here is the same; I'm just using a toy dataset because I'm behind a firewall.

import pandas as pd
import statsmodels.formula.api as sm

tips = pd.DataFrame({'day': list('MMMFFF'), 'tip': range(6),
                     'total_bill': [10, 40, 20, 80, 50, 40]})

def fit_predict(df):
    m = sm.ols("tip ~ total_bill", df).fit()
    return pd.Series(m.predict(df), index=df.index)

If you change 'transform' to 'apply', you'll get:

tips.groupby("day").apply(fit_predict)

day   
F    3    2.923077
     4    4.307692
     5    4.769231
M    0    0.714286
     1    1.357143
     2    0.928571

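That's not quite what you want because of the extra outer `day` level in the index, but if you drop level 0, what remains is the original row index, and column assignment aligns on it, so you can proceed as desired: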

tips['predicted'] = tips.groupby("day").apply(fit_predict).reset_index(level=0, drop=True)

  day  tip  total_bill  predicted
0   M    0          10   0.714286
1   M    1          40   1.357143
2   M    2          20   0.928571
3   F    3          80   2.923077
4   F    4          50   4.307692
5   F    5          40   4.769231
answered Nov 13 '22 01:11 by JohnE
