Python Statsmodels: Using SARIMAX with exogenous regressors to get predicted mean and confidence intervals

I'm using statsmodels.tsa.SARIMAX() to train a model with exogenous variables. Is there an equivalent of get_prediction() when a model is trained with exogenous variables so that the object returned contains the predicted mean and confidence interval rather than just an array of predicted mean results? The predict() and forecast() methods take exogenous variables, but only return the predicted mean value.

    import statsmodels.api as sm

    SARIMA_model = sm.tsa.SARIMAX(endog=y_train.astype('float64'),
                                  exog=ExogenousFeature_train.values.astype('float64'),
                                  order=(1, 0, 0),
                                  seasonal_order=(2, 1, 0, 7),
                                  simple_differencing=False)

    model_results = SARIMA_model.fit()

    # reshape(-1, 1) turns the single exogenous column into a 2-D (n_obs, 1) array
    pred = model_results.predict(start=train_end_date,
                                 end=test_end_date,
                                 exog=ExogenousFeature_test.values.astype('float64').reshape(-1, 1),
                                 dynamic=False)

Here pred is an array of predicted values rather than an object containing the predicted mean values and confidence intervals, which is what you would get from get_prediction(). Note that get_prediction() does not take exogenous variables.

My version of statsmodels is 0.8

asked Sep 26 '16 at 10:09 by Kishan Manani


1 Answer

There have been some backward-compatibility issues, which is why the full results (with prediction intervals, etc.) are not exposed by predict() and forecast().

To get what you want now, use the get_prediction() and get_forecast() methods with the parameters shown below:

    import matplotlib.pyplot as plt

    # get_prediction() is a method of the fitted results, not of the model.
    # For an out-of-sample range, exog must hold the regressor values for that range.
    pred_res = model_results.get_prediction(start=train_end_date,
                                            end=test_end_date,
                                            exog=ExogenousFeature_test.values.astype('float64').reshape(-1, 1))
    pred_means = pred_res.predicted_mean
    # The interval level is set via alpha in conf_int(): alpha=0.05 implies a 95% CI
    pred_cis = pred_res.conf_int(alpha=0.05)

    # You can then plot it
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1)
    # Actual data
    ax.plot(y_train.astype('float64'), '--', color="blue", label='data')
    # Predicted means
    ax.plot(pred_means, lw=1, color="black", alpha=0.5, label='SARIMAX')
    # Shaded confidence band between the lower and upper bounds
    ax.fill_between(pred_means.index, pred_cis.iloc[:, 0], pred_cis.iloc[:, 1], alpha=0.05)
    ax.legend(loc='upper right')
    plt.show()

For more info, go to:

  • https://github.com/statsmodels/statsmodels/issues/2823
  • Solution by the author: http://www.statsmodels.org/dev/examples/notebooks/generated/statespace_local_linear_trend.html
answered Nov 15 '22 at 14:11 by Vinay Kolar