Python: Predict the y value using Statsmodels - Linear Regression

I am using Python's statsmodels library to predict the future balance with linear regression. The CSV file is shown below:

Year | Balance
3 | 30
8 | 57
9 | 64
13 | 72
3 | 36
6 | 43
11 | 59
21 | 90
1 | 20
16 | 83
It contains 'Year' as the independent (x) variable and 'Balance' as the dependent (y) variable.

Here's the code for Linear Regression for this data:

import pandas as pd
import statsmodels.api as sm
from statsmodels.formula.api import ols
import numpy as np
from matplotlib import pyplot as plt

import os
os.chdir(r'C:\Users\Admin\Desktop\csv')  # raw string: in Python 3, '\U' in a normal string literal is a syntax error

cw = pd.read_csv('data-table.csv')
y=cw.Balance
X=cw.Year

X = sm.add_constant(X)  # Adds a constant term to the predictor

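est = sm.OLS(y, X)
est = est.fit()
print(est.summary())  # print() is a function in Python 3

print(est.params)  # fitted intercept ('const') and slope ('Year')

# 100 evenly spaced Year values across the observed range, as a column vector
X_prime = np.linspace(X.Year.min(), X.Year.max(), 100)[:, np.newaxis]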
X_prime = sm.add_constant(X_prime)  # add constant as we did before

y_hat = est.predict(X_prime)


plt.scatter(X.Year, y, alpha=0.3)  # Plot the raw data
plt.xlabel("Year")
plt.ylabel("Total Balance")
plt.plot(X_prime[:, 1], y_hat, 'r', alpha=0.9)  # Add the regression line, colored in red
plt.show()

How can I use statsmodels to predict the 'Balance' value when 'Year' = 10?

User456898 asked May 02 '16 07:05


1 Answer

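You can use the predict method of the fitted result object est. To pass the new values by column name, fit the model with the formula interface:

from statsmodels.formula.api import ols

est = ols("Balance ~ Year", data=cw).fit()
est.predict(exog=new_values)

where new_values is a dictionary (or DataFrame) keyed by the predictor name, e.g. new_values = {"Year": [10]}.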
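A self-contained sketch of the formula approach, assuming the CSV holds only the two columns from the table in the question (the ten rows are inlined here instead of reading data-table.csv). Because the model was built from a formula, predict() rebuilds the design matrix, intercept included, from a dictionary keyed by column name.

```python
import io

import pandas as pd
from statsmodels.formula.api import ols

# The ten rows from the question, inlined so the example runs without data-table.csv
CSV_TEXT = """Year,Balance
3,30
8,57
9,64
13,72
3,36
6,43
11,59
21,90
1,20
16,83
"""
cw = pd.read_csv(io.StringIO(CSV_TEXT))

# Fit Balance = Intercept + slope * Year using the formula interface
est = ols("Balance ~ Year", data=cw).fit()

# New predictor values, keyed by the column name used in the formula
new_values = {"Year": [10]}
predicted = est.predict(exog=new_values)
print(predicted)
```

Passing several years at once, e.g. {"Year": [10, 25]}, returns one prediction per value.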

Check out this link.

Diego Aguado answered Oct 22 '22 23:10