Return predicted values from a rolling regression grouped by id using Pandas

I am trying to calculate monthly rolling window regressions and return predicted values as a new column in the data frame. I know that Pandas has rolling regression capabilities (pandas.ols) that are in the process of being deprecated, so I'm interested in a solution that uses statsmodels or something similar.

I'd like to calculate monthly rolling regressions (12 month window, 6 month minimum) and save each month's prediction back to a new column in the original data frame. While my question is different, the closest solution I've found is in the answer to this question. Based on that answer I've tried this (the data is below):

import pandas as pd
import statsmodels.api as sm

def grp_ols_predict(df, xcols, ycol):
    return sm.OLS(df[ycol], df[xcols]).fit().predict()

retdata['predicted_y'] = retdata.groupby('id').apply(grp_ols_predict, xcols=['constant', 'x1', 'x2', 'x3'], ycol='y')

Two issues remain unresolved at this point.
1. This code runs without errors but returns all NaN values for predicted_y.
2. The regression above is not a rolling window regression. The syntax for this is straightforward in pandas.ols, but not in statsmodels. However, it seems that the idea is for the pandas.ols syntax to work in statsmodels at some point. The following code is rolling window, but will be deprecated in a future version and is not grouped by id:

model = pd.ols(y='y', x=retdata[['x1','x2','x3']], window_type='rolling', window=12, min_periods=6, intercept=True)
retdata['predicted_y'] = model.y_predict

My question is essentially "appending predicted values and residuals to pandas dataframe" with two additional complications (1) rolling window and (2) grouping by id.

Finally, a sample of the data I'm using:

{'constant': {0: 1,  1: 1,  2: 1,  3: 1,  4: 1,  5: 1,  6: 1,  7: 1,  8: 1,  9: 1,  10: 1,  11: 1,  12: 1,  13: 1,  14: 1,  15: 1,  16: 1,  17: 1,  18: 1,  19: 1,  20: 1,  21: 1,  22: 1,  23: 1,  24: 1,  25: 1,  26: 1,  27: 1,  28: 1,  29: 1,  30: 1,  31: 1,  32: 1,  33: 1,  34: 1,  35: 1,  36: 1,  37: 1,  38: 1,  39: 1,  40: 1,  41: 1,  42: 1,  43: 1,  44: 1,  45: 1,  46: 1,  47: 1,  48: 1,  49: 1,  50: 1,  51: 1,  52: 1,  53: 1,  54: 1,  55: 1,  56: 1,  57: 1,  58: 1,  59: 1,  60: 1,  61: 1,  62: 1,  63: 1,  64: 1,  65: 1,  66: 1,  67: 1,  68: 1,  69: 1,  70: 1,  71: 1,  72: 1,  73: 1,  74: 1,  75: 1,  76: 1,  77: 1,  78: 1,  79: 1,  80: 1,  81: 1,  82: 1,  83: 1},
'id': {0: 11111,  1: 11111,  2: 11111,  3: 11111,  4: 11111,  5: 11111,  6: 11111,  7: 11111,  8: 11111,  9: 11111,  10: 11111,  11: 11111,  12: 11111,  13: 11111,  14: 11111,  15: 11111,  16: 11111,  17: 11111,  18: 11111,  19: 11111,  20: 11111,  21: 11111,  22: 11111,  23: 11111,  24: 22222,  25: 22222,  26: 22222,  27: 22222,  28: 22222,  29: 22222,  30: 22222,  31: 22222,  32: 22222,  33: 22222,  34: 22222,  35: 22222,  36: 22222,  37: 22222,  38: 22222,  39: 22222,  40: 22222,  41: 22222,  42: 22222,  43: 22222,  44: 22222,  45: 22222,  46: 22222,  47: 22222,  48: 22222,  49: 22222,  50: 22222,  51: 22222,  52: 22222,  53: 22222,  54: 22222,  55: 22222,  56: 22222,  57: 22222,  58: 22222,  59: 22222,  60: 33333,  61: 33333,  62: 33333,  63: 33333,  64: 33333,  65: 33333,  66: 33333,  67: 33333,  68: 33333,  69: 33333,  70: 33333,  71: 33333,  72: 33333,  73: 33333,  74: 33333,  75: 33333,  76: 33333,  77: 33333,  78: 33333,  79: 33333,  80: 33333,  81: 33333,  82: 33333,  83: 33333},
'month': {0: 1,  1: 2,  2: 3,  3: 4,  4: 5,  5: 6,  6: 7,  7: 8,  8: 9,  9: 10,  10: 11,  11: 12,  12: 1,  13: 2,  14: 3,  15: 4,  16: 5,  17: 6,  18: 7,  19: 8,  20: 9,  21: 10,  22: 11,  23: 12,  24: 1,  25: 2,  26: 3,  27: 4,  28: 5,  29: 6,  30: 7,  31: 8,  32: 9,  33: 10,  34: 11,  35: 12,  36: 1,  37: 2,  38: 3,  39: 4,  40: 5,  41: 6,  42: 7,  43: 8,  44: 9,  45: 10,  46: 11,  47: 12,  48: 1,  49: 2,  50: 3,  51: 4,  52: 5,  53: 6,  54: 7,  55: 8,  56: 9,  57: 10,  58: 11,  59: 12,  60: 1,  61: 2,  62: 3,  63: 4,  64: 5,  65: 6,  66: 7,  67: 8,  68: 9,  69: 10,  70: 11,  71: 12,  72: 1,  73: 2,  74: 3,  75: 4,  76: 5,  77: 6,  78: 7,  79: 8,  80: 9,  81: 10,  82: 11,  83: 12},
'x1': {0: 4.8399999999999999,  1: 1.4099999999999999,  2: 4.1299999999999999,  3: 3.1499999999999999,  4: -3.98,  5: -0.10000000000000001,  6: -4.5,  7: 3.79,  8: -0.84999999999999998,  9: -4.4199999999999999,  10: -0.46000000000000002,  11: 8.7100000000000009,  12: 2.4900000000000002,  13: 2.8700000000000001,  14: 0.63,  15: 0.28999999999999998,  16: 1.25,  17: -2.4300000000000002,  18: -0.80000000000000004,  19: 3.2599999999999998,  20: -1.1399999999999999,  21: 0.52000000000000002,  22: 4.5999999999999996,  23: 0.62,  24: 4.8399999999999999,  25: 1.4099999999999999,  26: 4.1299999999999999,  27: 3.1499999999999999,  28: -3.98,  29: -0.10000000000000001,  30: -4.5,  31: 3.79,  32: -0.84999999999999998,  33: -4.4199999999999999,  34: -0.46000000000000002,  35: 8.7100000000000009,  36: 2.4900000000000002,  37: 2.8700000000000001,  38: 0.63,  39: 0.28999999999999998,  40: 1.25,  41: -2.4300000000000002,  42: -0.80000000000000004,  43: 3.2599999999999998,  44: -1.1399999999999999,  45: 0.52000000000000002,  46: 4.5999999999999996,  47: 0.62,  48: -3.29,  49: -4.8499999999999996,  50: -1.29,  51: -5.6799999999999997,  52: -2.9399999999999999,  53: -1.5600000000000001,  54: 5.04,  55: -3.8399999999999999,  56: 4.75,  57: -0.85999999999999999,  58: -12.74,  59: 0.57999999999999996,  60: 5.5700000000000003,  61: 1.29,  62: 4.0300000000000002,  63: 1.55,  64: 2.7999999999999998,  65: -1.2,  66: 5.6500000000000004,  67: -2.71,  68: 3.77,  69: 4.1799999999999997,  70: 3.1200000000000001,  71: 2.8100000000000001,  72: -3.3199999999999998,  73: 4.6500000000000004,  74: 0.42999999999999999,  75: -0.19,  76: 2.0600000000000001,  77: 2.6099999999999999,  78: -2.04,  79: 4.2400000000000002,  80: -1.97,  81: 2.52,  82: 2.5499999999999998,  83: -0.059999999999999998},
'x2': {0: 7.4400000000000004,  1: 1.8999999999999999,  2: 2.5699999999999998,  3: -0.47999999999999998,  4: -1.1000000000000001,  5: -1.4299999999999999,  6: -1.5,  7: -0.19,  8: 0.40999999999999998,  9: -1.78,  10: -2.8300000000000001,  11: 3.2799999999999998,  12: 6.1100000000000003,  13: 1.3899999999999999,  14: -0.27000000000000002,  15: -0.02,  16: -2.79,  17: 0.32000000000000001,  18: -2.8900000000000001,  19: -4.0700000000000003,  20: -2.6899999999999999,  21: -2.71,  22: -1.1200000000000001,  23: -1.8600000000000001,  24: 7.4400000000000004,  25: 1.8999999999999999,  26: 2.5699999999999998,  27: -0.47999999999999998,  28: -1.1000000000000001,  29: -1.4299999999999999,  30: -1.5,  31: -0.19,  32: 0.40999999999999998,  33: -1.78,  34: -2.8300000000000001,  35: 3.2799999999999998,  36: 6.1100000000000003,  37: 1.3899999999999999,  38: -0.27000000000000002,  39: -0.02,  40: -2.79,  41: 0.32000000000000001,  42: -2.8900000000000001,  43: -4.0700000000000003,  44: -2.6899999999999999,  45: -2.71,  46: -1.1200000000000001,  47: -1.8600000000000001,  48: -3.5,  49: -3.9900000000000002,  50: -2.8700000000000001,  51: -3.9900000000000002,  52: -6.1200000000000001,  53: -2.9399999999999999,  54: 7.8600000000000003,  55: -2.04,  56: 2.9100000000000001,  57: -0.17000000000000001,  58: -7.7000000000000002,  59: -5.3300000000000001,  60: 0.44,  61: -0.42999999999999999,  62: 0.83999999999999997,  63: -2.4300000000000002,  64: 1.6899999999999999,  65: 1.1699999999999999,  66: 1.8799999999999999,  67: 0.25,  68: 2.9399999999999999,  69: -1.52,  70: 1.25,  71: -0.47999999999999998,  72: 0.87,  73: 0.34000000000000002,  74: -1.8500000000000001,  75: -4.1900000000000004,  76: -1.8500000000000001,  77: 3.0099999999999998,  78: -4.2199999999999998,  79: 0.40000000000000002,  80: -3.7999999999999998,  81: 4.2800000000000002,  82: -2.0499999999999998,  83: 2.5899999999999999},
'x3': {0: 1.3500000000000001,  1: -1.3400000000000001,  2: -4.0,  3: 0.73999999999999999,  4: -1.3799999999999999,  5: -2.0,  6: 0.14000000000000001,  7: 2.7200000000000002,  8: -2.9500000000000002,  9: -0.47999999999999998,  10: -1.75,  11: -0.23999999999999999,  12: 2.0600000000000001,  13: -2.75,  14: -1.6599999999999999,  15: 0.39000000000000001,  16: -2.73,  17: -2.4199999999999999,  18: 0.77000000000000002,  19: 4.6399999999999997,  20: 0.5,  21: 1.3200000000000001,  22: 4.7599999999999998,  23: -2.2599999999999998,  24: 1.3500000000000001,  25: -1.3400000000000001,  26: -4.0,  27: 0.73999999999999999,  28: -1.3799999999999999,  29: -2.0,  30: 0.14000000000000001,  31: 2.7200000000000002,  32: -2.9500000000000002,  33: -0.47999999999999998,  34: -1.75,  35: -0.23999999999999999,  36: 2.0600000000000001,  37: -2.75,  38: -1.6599999999999999,  39: 0.39000000000000001,  40: -2.73,  41: -2.4199999999999999,  42: 0.77000000000000002,  43: 4.6399999999999997,  44: 0.5,  45: 1.3200000000000001,  46: 4.7599999999999998,  47: -2.2599999999999998,  48: 2.7000000000000002,  49: 1.7,  50: 2.8300000000000001,  51: 5.6900000000000004,  52: 0.20999999999999999,  53: 1.4199999999999999,  54: -5.1799999999999997,  55: 1.1899999999999999,  56: 2.1099999999999999,  57: 1.74,  58: 4.0099999999999998,  59: 4.2400000000000002,  60: 0.94999999999999996,  61: 0.11,  62: -0.26000000000000001,  63: 0.56999999999999995,  64: 2.4900000000000002,  65: -0.13,  66: 0.60999999999999999,  67: -2.77,  68: -1.2,  69: 1.1000000000000001,  70: 0.26000000000000001,  71: -0.31,  72: -2.1299999999999999,  73: -0.37,  74: 5.0300000000000002,  75: 1.1000000000000001,  76: -0.35999999999999999,  77: -0.66000000000000003,  78: -0.02,  79: -0.55000000000000004,  80: -1.1899999999999999,  81: -1.6799999999999999,  82: -2.98,  83: 2.1200000000000001},
'y': {0: 37.543945819999998,  1: 8.9742475529999997,  2: -2.3528754309999997,  3: 13.13251636,  4: -1.60429428,  5: -11.956497779999999,  6: -19.876604879999999,  7: -2.325516618,  8: -4.7618724569999999,  9: 3.1666054689999998,  10: -1.625982086,  11: 23.14051619,  12: 36.241578869999998,  13: -4.0393970439999993,  14: -1.5464071159999999,  15: -5.8638777849999997,  16: 1.1173513309999998,  17: -7.7348398829999994,  18: 1.1975707259999999,  19: 8.1657380679999996,  20: 1.0988696200000001,  21: -4.8912916910000002,  22: 15.31432558,  23: -0.49755575099999999,  24: 2.439007991,  25: 3.7788248100000001,  26: 6.2406021170000008,  27: 0.070041193000000002,  28: -8.2320061649999996,  29: -3.0580604539999996,  30: -8.1230234560000003,  31: 4.824015073,  32: -0.082216824000000008,  33: -1.0699493369999999,  34: 2.0965058669999999,  35: 10.147223650000001,  36: 9.3610165409999997,  37: 0.50276726500000002,  38: 3.731305892,  39: 0.98107468400000009,  40: 3.3937931360000002,  41: -1.445663699,  42: 2.2321845640000002,  43: 2.2707284099999998,  44: -0.48955173399999996,  45: -5.1661444639999994,  46: 1.776962626,  47: 2.8132786730000001,  48: 8.3333586369999999,  49: -0.59700207599999999,  50: 0.0,  51: -5.4461723210000006,  52: -3.2260780789999997,  53: 0.71489267299999992,  54: -0.78864414099999991,  55: -3.936371727,  56: -14.285801190000001,  57: 8.6241378770000008,  58: -5.0419731539999999,  59: -6.8867527329999998,  60: 2.7716522460000004,  61: 2.1129326050000001,  62: 2.8956834530000002,  63: 15.714036009999999,  64: 6.1329305139999999,  65: -1.017191977,  66: -7.8303661889999994,  67: 5.6218592960000002,  68: -0.35928143700000004,  69: 6.385216346,  70: 8.4875017649999993,  71: -1.8882769469999998,  72: 1.1494252870000001,  73: 1.9820295980000002,  74: 6.9955625160000006,  75: -1.4393754569999999,  76: 2.0297029700000002,  77: 1.8563751830000002,  78: 3.5011990410000005,  79: 5.9082483779999997,  80: 2.0471054369999999,  81: 1.272648835,  82: 2.49201278,  83: 
-2.844593181},
'year': {0: 1971,  1: 1971,  2: 1971,  3: 1971,  4: 1971,  5: 1971,  6: 1971,  7: 1971,  8: 1971,  9: 1971,  10: 1971,  11: 1971,  12: 1972,  13: 1972,  14: 1972,  15: 1972,  16: 1972,  17: 1972,  18: 1972,  19: 1972,  20: 1972,  21: 1972,  22: 1972,  23: 1972,  24: 1971,  25: 1971,  26: 1971,  27: 1971,  28: 1971,  29: 1971,  30: 1971,  31: 1971,  32: 1971,  33: 1971,  34: 1971,  35: 1971,  36: 1972,  37: 1972,  38: 1972,  39: 1972,  40: 1972,  41: 1972,  42: 1972,  43: 1972,  44: 1972,  45: 1972,  46: 1972,  47: 1972,  48: 1973,  49: 1973,  50: 1973,  51: 1973,  52: 1973,  53: 1973,  54: 1973,  55: 1973,  56: 1973,  57: 1973,  58: 1973,  59: 1973,  60: 2013,  61: 2013,  62: 2013,  63: 2013,  64: 2013,  65: 2013,  66: 2013,  67: 2013,  68: 2013,  69: 2013,  70: 2013,  71: 2013,  72: 2014,  73: 2014,  74: 2014,  75: 2014,  76: 2014,  77: 2014,  78: 2014,  79: 2014,  80: 2014,  81: 2014,   82: 2014,    83: 2014}}
Arthur Morris asked Mar 02 '17 15:03


1 Answer

pandas' rolling seems to have some limitations. First, it seems impossible to pass an entire data frame through apply; only the values of a single column are passed. To get around this, we pass the index values through apply, which lets us look up the relevant data frame subset inside the applied function itself.

Second, the returned value needs to be a float. That does not help here, because predict() on a fitted sm.OLS model returns an array of values. To work around this, we save the results in a separate container as a side effect and extract them later on.

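import numpy as np

# df is the data frame from the question (called retdata there)
def ols_predict(indices, result, ycol, xcols):
    # rolling.apply hands over the window as floats (array or Series); recover integer labels
    idx = np.asarray(indices).astype(int)
    roll_df = df.loc[idx]  # get relevant data frame subset
    result[idx[-1]] = sm.OLS(roll_df[ycol], roll_df[xcols]).fit().predict()
    return 0  # value is irrelevant here; rolling.apply only needs a number back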

# define kwargs to be fed to ols_predict
kwargs = {"xcols": ['constant','x1', 'x2', 'x3'], 
          "ycol": 'y', "result": {}}

# iterate over each id's sub data frame and run the OLS on its rolling windows
df["identifier"] = df.index
for idx, sub_df in df.groupby("id"):
    sub_df["identifier"].rolling(12, min_periods=6).apply(ols_predict, kwargs=kwargs)

# write results back to original df
df["parameters"] = pd.Series(kwargs["result"])

# showing the last 5 computed values
print(df["parameters"].tail())

79    [2.71069564365, 3.86510820198, 3.65972798601, ...
80    [4.05363775104, 4.22653362401, 3.03918230523, ...
81    [3.55589161647, 2.49348201521, 1.20113347853, ...
82    [2.28561308212, 1.0537258681, 2.40806914305, 4...
83    [-0.428928897229, 3.22009689097, 3.30943586961...
Name: parameters, dtype: object
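
Each entry in the parameters column holds the fitted values for the whole window, not a single number. If what you want is one prediction per month, one option is to keep only the last element of each stored array, i.e. the fitted value for the window's most recent row. Below is a minimal sketch under that assumption; the small parameters Series and the helper name last_fitted are made up for illustration and stand in for df["parameters"].

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for df["parameters"]: NaN where no window was fitted,
# otherwise an array of fitted values for that window.
parameters = pd.Series(
    [np.nan, np.array([1.0, 2.0]), np.array([1.5, 2.5, 3.5])],
    dtype=object,
)


def last_fitted(value):
    """Return the last fitted value of a window, or NaN if nothing was stored."""
    # atleast_1d turns a scalar NaN into a one-element array, so both cases
    # can be indexed the same way.
    arr = np.atleast_1d(np.asarray(value, dtype=float))
    return float(arr[-1])


predicted_y = parameters.apply(last_fitted)
print(predicted_y)
```

With the real data this would be df["predicted_y"] = df["parameters"].apply(last_fitted).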

Overall, the workaround is rather ugly because it relies on side effects. However, it accomplishes what you require. You can now modify the OLS function as needed.
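
If the side effects bother you, a plain loop over the groups and window end points may be easier to read. The following is only a sketch, not a tested drop-in: the function name rolling_group_predict and the toy data are invented for illustration, it assumes a unique index and rows already sorted by time within each id, and it swaps sm.OLS for np.linalg.lstsq to keep the example free of extra dependencies (both compute an ordinary least-squares fit).

```python
import numpy as np
import pandas as pd


def rolling_group_predict(frame, xcols, ycol, window=12, min_periods=6):
    """Per row, the fitted value from a least-squares fit on the trailing window."""
    out = pd.Series(np.nan, index=frame.index, dtype=float)
    for _, grp in frame.groupby("id"):
        X = grp[xcols].to_numpy(dtype=float)
        y = grp[ycol].to_numpy(dtype=float)
        # `end` is exclusive: the window covers rows start..end-1 of this id
        for end in range(min_periods, len(grp) + 1):
            start = max(0, end - window)
            beta, *_ = np.linalg.lstsq(X[start:end], y[start:end], rcond=None)
            # in-sample prediction for the most recent row of the window
            out.loc[grp.index[end - 1]] = X[end - 1] @ beta
    return out


# Hypothetical toy data: two ids, 15 months each, y exactly linear in x1
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "id": np.repeat([1, 2], 15),
    "constant": 1.0,
    "x1": rng.normal(size=30),
})
toy["y"] = 2.0 + 3.0 * toy["x1"]
toy["predicted_y"] = rolling_group_predict(toy, ["constant", "x1"], "y")
print(toy.head(8))
```

Because the returned Series carries the original row labels, the assignment lines up row by row. The groupby/apply attempt in the question returns one array per id, indexed by id rather than by row, so none of its labels match the frame's index on assignment, which would explain the all-NaN column.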

pansen answered Oct 23 '22 02:10
