I am trying to calculate monthly rolling window regressions and return the predicted values as a new column in the data frame. I know that pandas has rolling regression capabilities (pandas.ols) that are in the process of being deprecated, so I'm interested in a solution that uses statsmodels or something similar.
I'd like to calculate monthly rolling regressions (12 month window, 6 month minimum) and save each month's prediction back to a new column in the original data frame. While my question is different, the closest solution I've found is in the answer to this question. Based on that answer I've tried this (the data is below):
import pandas as pd
import statsmodels.api as sm

def grp_ols_predict(df, xcols, ycol):
    return sm.OLS(df[ycol], df[xcols]).fit().predict()

retdata['predicted_y'] = retdata.groupby('id').apply(grp_ols_predict, xcols=['constant','x1', 'x2', 'x3'], ycol='y')
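Why the groupby version above comes back as all NaN (issue 1 below): when the applied function returns a plain array, groupby('id').apply yields one entry per id, indexed by the id values (11111, 22222, ...). Assigning that to a new column aligns it against the row labels 0, 1, 2, ..., which never match. A minimal sketch on hypothetical toy data, with np.linalg.lstsq standing in for sm.OLS, showing the cause and one workaround: wrap each group's predictions in a Series that keeps the group's row index, then concatenate.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: two ids with different numbers of rows,
# where y = 1 + 2*x1 exactly, so the fitted values equal y.
retdata = pd.DataFrame({
    "id": [11111] * 4 + [22222] * 5,
    "constant": 1.0,
    "x1": [0.0, 1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 5.0, 7.0],
})
retdata["y"] = 1.0 + 2.0 * retdata["x1"]
xcols, ycol = ["constant", "x1"], "y"

def ols_fitted(g):
    # Least-squares fit; np.linalg.lstsq stands in for sm.OLS(...).fit().predict()
    X = g[xcols].to_numpy()
    beta, *_ = np.linalg.lstsq(X, g[ycol].to_numpy(), rcond=None)
    return X @ beta

# Returning a bare array: the result is indexed by id, not by row label.
bad = retdata.groupby("id").apply(ols_fitted)
print(bad.index.tolist())  # [11111, 22222]: no overlap with row labels 0..8

# Returning a Series that keeps each group's row index lets pandas align it.
pred = pd.concat(
    pd.Series(ols_fitted(g), index=g.index) for _, g in retdata.groupby("id")
)
retdata["predicted_y"] = pred
```

This still fits one regression per id on all of its rows; it fixes the alignment, not the missing rolling window.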
There are two unresolved issues at this point:
1. This code runs without errors but returns all NaN values for predicted_y.
2. The regression above is not a rolling-window regression. The syntax for this is straightforward in pandas.ols, but not in statsmodels. However, it seems that the idea is for the pandas.ols syntax to work in statsmodels at some point. The following code does use a rolling window, but pandas.ols is deprecated and will be removed in a future version, and the code is not grouped by id:
model = pd.ols(y=retdata['y'], x=retdata[['x1','x2','x3']], window_type='rolling', window=12, min_periods=6, intercept=True)
retdata['predicted_y'] = model.y_predict
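For reference, here is what window=12 with min_periods=6 means for pandas' own rolling(), which counts rows rather than dates: no value for the first five rows, then a window that grows from 6 to 12 rows and stays at 12. This sketch assumes one row per month for a single hypothetical id and counts the rows in each window by summing a column of ones:

```python
import pandas as pd

# One row per month for a hypothetical 24-month series of a single id.
ones = pd.Series(1.0, index=range(24))

# Number of rows in each trailing window (NaN while fewer than 6 are available).
window_sizes = ones.rolling(window=12, min_periods=6).sum()
print(window_sizes.tolist()[:13])
# [nan, nan, nan, nan, nan, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 12.0]
```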
My question is essentially "appending predicted values and residuals to pandas dataframe", with two additional complications: (1) a rolling window and (2) grouping by id.
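Both complications can be handled together without pandas.ols: split the frame by id, and within each id fit OLS on each trailing window of up to 12 rows (at least 6), keeping the fitted value for the window's last month. The sketch below does this with plain loops and np.linalg.lstsq, swapped in for sm.OLS (for a full-rank design it gives the same coefficients). It assumes rows are already sorted by year and month within each id; the toy data and the helper name rolling_predict are hypothetical.

```python
import numpy as np
import pandas as pd

def rolling_predict(g, xcols, ycol, window=12, min_periods=6):
    """Hypothetical helper: per-row prediction from OLS on that row's trailing window."""
    X = g[xcols].to_numpy(dtype=float)
    y = g[ycol].to_numpy(dtype=float)
    out = np.full(len(g), np.nan)
    for end in range(len(g)):
        start = max(0, end - window + 1)
        if end + 1 - start < min_periods:
            continue  # fewer than min_periods months so far: leave NaN
        beta, *_ = np.linalg.lstsq(X[start:end + 1], y[start:end + 1], rcond=None)
        out[end] = X[end] @ beta  # fitted value for the window's last month
    return pd.Series(out, index=g.index)  # keep row labels so assignment aligns

# Hypothetical toy data: two ids (15 and 8 months), each with its own exact line.
x = np.arange(23, dtype=float)
retdata = pd.DataFrame({"id": [11111] * 15 + [22222] * 8, "constant": 1.0, "x1": x})
retdata["y"] = np.where(retdata["id"] == 11111, 1.0 + 2.0 * x, -3.0 + 0.5 * x)

# Windows never cross ids because each id is processed separately.
retdata["predicted_y"] = pd.concat(
    rolling_predict(g, ["constant", "x1"], "y") for _, g in retdata.groupby("id")
)
```

With the question's data the call would pass xcols=['constant', 'x1', 'x2', 'x3']. Newer statsmodels releases also ship statsmodels.regression.rolling.RollingOLS for the rolling part, which this sketch does not use.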
Finally, a sample of the data I'm using:
{'constant': {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 1, 28: 1, 29: 1, 30: 1, 31: 1, 32: 1, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1, 39: 1, 40: 1, 41: 1, 42: 1, 43: 1, 44: 1, 45: 1, 46: 1, 47: 1, 48: 1, 49: 1, 50: 1, 51: 1, 52: 1, 53: 1, 54: 1, 55: 1, 56: 1, 57: 1, 58: 1, 59: 1, 60: 1, 61: 1, 62: 1, 63: 1, 64: 1, 65: 1, 66: 1, 67: 1, 68: 1, 69: 1, 70: 1, 71: 1, 72: 1, 73: 1, 74: 1, 75: 1, 76: 1, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 1},
'id': {0: 11111, 1: 11111, 2: 11111, 3: 11111, 4: 11111, 5: 11111, 6: 11111, 7: 11111, 8: 11111, 9: 11111, 10: 11111, 11: 11111, 12: 11111, 13: 11111, 14: 11111, 15: 11111, 16: 11111, 17: 11111, 18: 11111, 19: 11111, 20: 11111, 21: 11111, 22: 11111, 23: 11111, 24: 22222, 25: 22222, 26: 22222, 27: 22222, 28: 22222, 29: 22222, 30: 22222, 31: 22222, 32: 22222, 33: 22222, 34: 22222, 35: 22222, 36: 22222, 37: 22222, 38: 22222, 39: 22222, 40: 22222, 41: 22222, 42: 22222, 43: 22222, 44: 22222, 45: 22222, 46: 22222, 47: 22222, 48: 22222, 49: 22222, 50: 22222, 51: 22222, 52: 22222, 53: 22222, 54: 22222, 55: 22222, 56: 22222, 57: 22222, 58: 22222, 59: 22222, 60: 33333, 61: 33333, 62: 33333, 63: 33333, 64: 33333, 65: 33333, 66: 33333, 67: 33333, 68: 33333, 69: 33333, 70: 33333, 71: 33333, 72: 33333, 73: 33333, 74: 33333, 75: 33333, 76: 33333, 77: 33333, 78: 33333, 79: 33333, 80: 33333, 81: 33333, 82: 33333, 83: 33333},
'month': {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 1, 13: 2, 14: 3, 15: 4, 16: 5, 17: 6, 18: 7, 19: 8, 20: 9, 21: 10, 22: 11, 23: 12, 24: 1, 25: 2, 26: 3, 27: 4, 28: 5, 29: 6, 30: 7, 31: 8, 32: 9, 33: 10, 34: 11, 35: 12, 36: 1, 37: 2, 38: 3, 39: 4, 40: 5, 41: 6, 42: 7, 43: 8, 44: 9, 45: 10, 46: 11, 47: 12, 48: 1, 49: 2, 50: 3, 51: 4, 52: 5, 53: 6, 54: 7, 55: 8, 56: 9, 57: 10, 58: 11, 59: 12, 60: 1, 61: 2, 62: 3, 63: 4, 64: 5, 65: 6, 66: 7, 67: 8, 68: 9, 69: 10, 70: 11, 71: 12, 72: 1, 73: 2, 74: 3, 75: 4, 76: 5, 77: 6, 78: 7, 79: 8, 80: 9, 81: 10, 82: 11, 83: 12},
'x1': {0: 4.8399999999999999, 1: 1.4099999999999999, 2: 4.1299999999999999, 3: 3.1499999999999999, 4: -3.98, 5: -0.10000000000000001, 6: -4.5, 7: 3.79, 8: -0.84999999999999998, 9: -4.4199999999999999, 10: -0.46000000000000002, 11: 8.7100000000000009, 12: 2.4900000000000002, 13: 2.8700000000000001, 14: 0.63, 15: 0.28999999999999998, 16: 1.25, 17: -2.4300000000000002, 18: -0.80000000000000004, 19: 3.2599999999999998, 20: -1.1399999999999999, 21: 0.52000000000000002, 22: 4.5999999999999996, 23: 0.62, 24: 4.8399999999999999, 25: 1.4099999999999999, 26: 4.1299999999999999, 27: 3.1499999999999999, 28: -3.98, 29: -0.10000000000000001, 30: -4.5, 31: 3.79, 32: -0.84999999999999998, 33: -4.4199999999999999, 34: -0.46000000000000002, 35: 8.7100000000000009, 36: 2.4900000000000002, 37: 2.8700000000000001, 38: 0.63, 39: 0.28999999999999998, 40: 1.25, 41: -2.4300000000000002, 42: -0.80000000000000004, 43: 3.2599999999999998, 44: -1.1399999999999999, 45: 0.52000000000000002, 46: 4.5999999999999996, 47: 0.62, 48: -3.29, 49: -4.8499999999999996, 50: -1.29, 51: -5.6799999999999997, 52: -2.9399999999999999, 53: -1.5600000000000001, 54: 5.04, 55: -3.8399999999999999, 56: 4.75, 57: -0.85999999999999999, 58: -12.74, 59: 0.57999999999999996, 60: 5.5700000000000003, 61: 1.29, 62: 4.0300000000000002, 63: 1.55, 64: 2.7999999999999998, 65: -1.2, 66: 5.6500000000000004, 67: -2.71, 68: 3.77, 69: 4.1799999999999997, 70: 3.1200000000000001, 71: 2.8100000000000001, 72: -3.3199999999999998, 73: 4.6500000000000004, 74: 0.42999999999999999, 75: -0.19, 76: 2.0600000000000001, 77: 2.6099999999999999, 78: -2.04, 79: 4.2400000000000002, 80: -1.97, 81: 2.52, 82: 2.5499999999999998, 83: -0.059999999999999998},
'x2': {0: 7.4400000000000004, 1: 1.8999999999999999, 2: 2.5699999999999998, 3: -0.47999999999999998, 4: -1.1000000000000001, 5: -1.4299999999999999, 6: -1.5, 7: -0.19, 8: 0.40999999999999998, 9: -1.78, 10: -2.8300000000000001, 11: 3.2799999999999998, 12: 6.1100000000000003, 13: 1.3899999999999999, 14: -0.27000000000000002, 15: -0.02, 16: -2.79, 17: 0.32000000000000001, 18: -2.8900000000000001, 19: -4.0700000000000003, 20: -2.6899999999999999, 21: -2.71, 22: -1.1200000000000001, 23: -1.8600000000000001, 24: 7.4400000000000004, 25: 1.8999999999999999, 26: 2.5699999999999998, 27: -0.47999999999999998, 28: -1.1000000000000001, 29: -1.4299999999999999, 30: -1.5, 31: -0.19, 32: 0.40999999999999998, 33: -1.78, 34: -2.8300000000000001, 35: 3.2799999999999998, 36: 6.1100000000000003, 37: 1.3899999999999999, 38: -0.27000000000000002, 39: -0.02, 40: -2.79, 41: 0.32000000000000001, 42: -2.8900000000000001, 43: -4.0700000000000003, 44: -2.6899999999999999, 45: -2.71, 46: -1.1200000000000001, 47: -1.8600000000000001, 48: -3.5, 49: -3.9900000000000002, 50: -2.8700000000000001, 51: -3.9900000000000002, 52: -6.1200000000000001, 53: -2.9399999999999999, 54: 7.8600000000000003, 55: -2.04, 56: 2.9100000000000001, 57: -0.17000000000000001, 58: -7.7000000000000002, 59: -5.3300000000000001, 60: 0.44, 61: -0.42999999999999999, 62: 0.83999999999999997, 63: -2.4300000000000002, 64: 1.6899999999999999, 65: 1.1699999999999999, 66: 1.8799999999999999, 67: 0.25, 68: 2.9399999999999999, 69: -1.52, 70: 1.25, 71: -0.47999999999999998, 72: 0.87, 73: 0.34000000000000002, 74: -1.8500000000000001, 75: -4.1900000000000004, 76: -1.8500000000000001, 77: 3.0099999999999998, 78: -4.2199999999999998, 79: 0.40000000000000002, 80: -3.7999999999999998, 81: 4.2800000000000002, 82: -2.0499999999999998, 83: 2.5899999999999999},
'x3': {0: 1.3500000000000001, 1: -1.3400000000000001, 2: -4.0, 3: 0.73999999999999999, 4: -1.3799999999999999, 5: -2.0, 6: 0.14000000000000001, 7: 2.7200000000000002, 8: -2.9500000000000002, 9: -0.47999999999999998, 10: -1.75, 11: -0.23999999999999999, 12: 2.0600000000000001, 13: -2.75, 14: -1.6599999999999999, 15: 0.39000000000000001, 16: -2.73, 17: -2.4199999999999999, 18: 0.77000000000000002, 19: 4.6399999999999997, 20: 0.5, 21: 1.3200000000000001, 22: 4.7599999999999998, 23: -2.2599999999999998, 24: 1.3500000000000001, 25: -1.3400000000000001, 26: -4.0, 27: 0.73999999999999999, 28: -1.3799999999999999, 29: -2.0, 30: 0.14000000000000001, 31: 2.7200000000000002, 32: -2.9500000000000002, 33: -0.47999999999999998, 34: -1.75, 35: -0.23999999999999999, 36: 2.0600000000000001, 37: -2.75, 38: -1.6599999999999999, 39: 0.39000000000000001, 40: -2.73, 41: -2.4199999999999999, 42: 0.77000000000000002, 43: 4.6399999999999997, 44: 0.5, 45: 1.3200000000000001, 46: 4.7599999999999998, 47: -2.2599999999999998, 48: 2.7000000000000002, 49: 1.7, 50: 2.8300000000000001, 51: 5.6900000000000004, 52: 0.20999999999999999, 53: 1.4199999999999999, 54: -5.1799999999999997, 55: 1.1899999999999999, 56: 2.1099999999999999, 57: 1.74, 58: 4.0099999999999998, 59: 4.2400000000000002, 60: 0.94999999999999996, 61: 0.11, 62: -0.26000000000000001, 63: 0.56999999999999995, 64: 2.4900000000000002, 65: -0.13, 66: 0.60999999999999999, 67: -2.77, 68: -1.2, 69: 1.1000000000000001, 70: 0.26000000000000001, 71: -0.31, 72: -2.1299999999999999, 73: -0.37, 74: 5.0300000000000002, 75: 1.1000000000000001, 76: -0.35999999999999999, 77: -0.66000000000000003, 78: -0.02, 79: -0.55000000000000004, 80: -1.1899999999999999, 81: -1.6799999999999999, 82: -2.98, 83: 2.1200000000000001},
'y': {0: 37.543945819999998, 1: 8.9742475529999997, 2: -2.3528754309999997, 3: 13.13251636, 4: -1.60429428, 5: -11.956497779999999, 6: -19.876604879999999, 7: -2.325516618, 8: -4.7618724569999999, 9: 3.1666054689999998, 10: -1.625982086, 11: 23.14051619, 12: 36.241578869999998, 13: -4.0393970439999993, 14: -1.5464071159999999, 15: -5.8638777849999997, 16: 1.1173513309999998, 17: -7.7348398829999994, 18: 1.1975707259999999, 19: 8.1657380679999996, 20: 1.0988696200000001, 21: -4.8912916910000002, 22: 15.31432558, 23: -0.49755575099999999, 24: 2.439007991, 25: 3.7788248100000001, 26: 6.2406021170000008, 27: 0.070041193000000002, 28: -8.2320061649999996, 29: -3.0580604539999996, 30: -8.1230234560000003, 31: 4.824015073, 32: -0.082216824000000008, 33: -1.0699493369999999, 34: 2.0965058669999999, 35: 10.147223650000001, 36: 9.3610165409999997, 37: 0.50276726500000002, 38: 3.731305892, 39: 0.98107468400000009, 40: 3.3937931360000002, 41: -1.445663699, 42: 2.2321845640000002, 43: 2.2707284099999998, 44: -0.48955173399999996, 45: -5.1661444639999994, 46: 1.776962626, 47: 2.8132786730000001, 48: 8.3333586369999999, 49: -0.59700207599999999, 50: 0.0, 51: -5.4461723210000006, 52: -3.2260780789999997, 53: 0.71489267299999992, 54: -0.78864414099999991, 55: -3.936371727, 56: -14.285801190000001, 57: 8.6241378770000008, 58: -5.0419731539999999, 59: -6.8867527329999998, 60: 2.7716522460000004, 61: 2.1129326050000001, 62: 2.8956834530000002, 63: 15.714036009999999, 64: 6.1329305139999999, 65: -1.017191977, 66: -7.8303661889999994, 67: 5.6218592960000002, 68: -0.35928143700000004, 69: 6.385216346, 70: 8.4875017649999993, 71: -1.8882769469999998, 72: 1.1494252870000001, 73: 1.9820295980000002, 74: 6.9955625160000006, 75: -1.4393754569999999, 76: 2.0297029700000002, 77: 1.8563751830000002, 78: 3.5011990410000005, 79: 5.9082483779999997, 80: 2.0471054369999999, 81: 1.272648835, 82: 2.49201278, 83: -2.844593181},
'year': {0: 1971, 1: 1971, 2: 1971, 3: 1971, 4: 1971, 5: 1971, 6: 1971, 7: 1971, 8: 1971, 9: 1971, 10: 1971, 11: 1971, 12: 1972, 13: 1972, 14: 1972, 15: 1972, 16: 1972, 17: 1972, 18: 1972, 19: 1972, 20: 1972, 21: 1972, 22: 1972, 23: 1972, 24: 1971, 25: 1971, 26: 1971, 27: 1971, 28: 1971, 29: 1971, 30: 1971, 31: 1971, 32: 1971, 33: 1971, 34: 1971, 35: 1971, 36: 1972, 37: 1972, 38: 1972, 39: 1972, 40: 1972, 41: 1972, 42: 1972, 43: 1972, 44: 1972, 45: 1972, 46: 1972, 47: 1972, 48: 1973, 49: 1973, 50: 1973, 51: 1973, 52: 1973, 53: 1973, 54: 1973, 55: 1973, 56: 1973, 57: 1973, 58: 1973, 59: 1973, 60: 2013, 61: 2013, 62: 2013, 63: 2013, 64: 2013, 65: 2013, 66: 2013, 67: 2013, 68: 2013, 69: 2013, 70: 2013, 71: 2013, 72: 2014, 73: 2014, 74: 2014, 75: 2014, 76: 2014, 77: 2014, 78: 2014, 79: 2014, 80: 2014, 81: 2014, 82: 2014, 83: 2014}}
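A row-based window of 12 covers 12 calendar months only if each id has exactly one row per month with no gaps. A quick check, sketched on a few id/year/month rows copied from the sample above (the full dict, assumed here to be bound to a name such as data, would be loaded the same way with pd.DataFrame):

```python
import pandas as pd

# A few id/year/month rows copied from the sample, standing in for the full dict.
data = {"id": {22: 11111, 23: 11111, 24: 22222, 25: 22222},
        "year": {22: 1972, 23: 1972, 24: 1971, 25: 1971},
        "month": {22: 11, 23: 12, 24: 1, 25: 2}}
retdata = pd.DataFrame(data).sort_values(["id", "year", "month"])

# Months elapsed since year 0; consecutive months differ by exactly 1.
month_no = retdata["year"] * 12 + retdata["month"]
gap = month_no.groupby(retdata["id"]).diff()

# True when no id skips a month, so 12 rows really mean 12 months.
no_gaps = gap.dropna().eq(1).all()
print(no_gaps)  # True
```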
pandas' rolling() seems to have some limitations. First, it seems impossible to pass an entire data frame through apply; only the values of a single column are passed. To get around this, we pass the index through apply, which lets us retrieve the relevant data frame subset within the applied function itself.
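A minimal sketch of the index trick on a hypothetical two-column frame: the rolling window runs over a column of row labels, and the applied function uses those labels to pull the full multi-column slice. It assumes a default RangeIndex and passes raw=True so each window arrives as a NumPy array; pandas converts the labels to floats, hence the cast back to int.

```python
import pandas as pd

df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [10.0, 20.0, 30.0, 40.0]})
df["identifier"] = df.index  # row labels, as in the answer

windows = {}  # hypothetical container: last row label -> full window slice

def grab_window(labels):
    rows = labels.astype(int)  # labels arrive as floats
    windows[rows[-1]] = df.loc[rows, ["a", "b"]]  # every column, not just "identifier"
    return 0.0  # rolling().apply() still needs a float back

df["identifier"].rolling(3, min_periods=2).apply(grab_window, raw=True)
print(windows[3])
```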
Second, the returned value needs to be a float. That is of no use here, because predict() on the fitted sm.OLS model returns an array of values (one per row in the window). To work around this, we save the results in an extra container as a side effect and extract them later.
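import pandas as pd
import statsmodels.api as sm

# df holds the sample data from the question (e.g. df = pd.DataFrame(data))
# and keeps its default RangeIndex, so row labels can travel through rolling()
def ols_predict(indices, result, ycol, xcols):
    rows = indices.astype(int)  # rolling() hands the labels over as floats
    roll_df = df.loc[rows]  # get the relevant data frame subset
    result[rows[-1]] = sm.OLS(roll_df[ycol], roll_df[xcols]).fit().predict()
    return 0.0  # the value is irrelevant here

# define the kwargs to be fed to ols_predict
kwargs = {"xcols": ['constant', 'x1', 'x2', 'x3'],
          "ycol": 'y', "result": {}}

# iterate over each id's sub data frame and call OLS on the rolling windows
df["identifier"] = df.index
for idx, sub_df in df.groupby("id"):
    # raw=True passes a NumPy array, so indices[-1] is positional
    sub_df["identifier"].rolling(12, min_periods=6).apply(ols_predict, raw=True, kwargs=kwargs)

# write the results back to the original df
df["parameters"] = pd.Series(kwargs["result"])

# show the last 5 computed values
print(df["parameters"].tail())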
79 [2.71069564365, 3.86510820198, 3.65972798601, ...
80 [4.05363775104, 4.22653362401, 3.03918230523, ...
81 [3.55589161647, 2.49348201521, 1.20113347853, ...
82 [2.28561308212, 1.0537258681, 2.40806914305, 4...
83 [-0.428928897229, 3.22009689097, 3.30943586961...
Name: parameters, dtype: object
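Each cell of parameters holds the fitted values for the whole window ending at that row, since that is what predict() returns. The question asks for one prediction per month; because each window ends at its own row, that prediction is the last element of the array. A small sketch of extracting it, on a hypothetical stand-in for df["parameters"]:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for df["parameters"]: NaN where the window was too short,
# otherwise the array of fitted values for the whole window ending at that row.
cells = np.empty(3, dtype=object)
cells[0] = np.nan
cells[1] = np.array([1.0, 2.0])
cells[2] = np.array([1.0, 2.0, 3.0])
parameters = pd.Series(cells, index=[0, 1, 2])

# The last fitted value of each window is the prediction for that row's month.
predicted_y = parameters.apply(lambda a: a[-1] if isinstance(a, np.ndarray) else np.nan)
print(predicted_y.tolist())  # [nan, 2.0, 3.0]
```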
Overall, the workarounds are rather ugly because they rely on side effects. However, they accomplish what you require, and you can now modify the OLS function to do whatever you need.
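One such modification, sketched under assumptions: if only the prediction for each window's last month is needed, the function can return that single fitted value as a float, so rolling().apply() fills the column directly and the side-effect container is no longer needed. The data is hypothetical (one id, exact linear y), np.linalg.lstsq stands in for sm.OLS, and a default RangeIndex is assumed.

```python
import numpy as np
import pandas as pd

# Hypothetical single-id data with an exact relation y = 2 + 0.5 * x1.
df = pd.DataFrame({"constant": 1.0, "x1": np.arange(10, dtype=float)})
df["y"] = 2.0 + 0.5 * df["x1"]
xcols, ycol = ["constant", "x1"], "y"

def last_fitted(labels):
    win = df.loc[labels.astype(int)]  # window rows, all columns
    X, y = win[xcols].to_numpy(), win[ycol].to_numpy()
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)  # stands in for sm.OLS(y, X).fit()
    return float(X[-1] @ beta)  # one float: fitted value for the window's last row

df["predicted_y"] = (
    pd.Series(df.index, index=df.index, dtype=float)
    .rolling(4, min_periods=3)
    .apply(last_fitted, raw=True)
)
```

For several ids, the same rolling call would run inside the groupby loop, with each id's returned Series concatenated before assigning the column.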