
How to predict new values using statsmodels.formula.api (python)

I trained a logistic model on the breast cancer data using only one feature, 'mean_area':

from statsmodels.formula.api import logit
logistic_model = logit('target ~ mean_area',breast)
result = logistic_model.fit()

The fitted model has a built-in predict method. However, called without arguments it returns the predicted values for all the training samples, as follows:

predictions = result.predict()

Suppose I want the prediction for a new value, say 30. How do I use the trained model to output the value (rather than reading the coefficients and computing it manually)?

vishmay asked Aug 15 '16 14:08



2 Answers

You can pass new values to the fitted model's .predict() method, as illustrated for a single observation in output #11 of this notebook from the docs. You can pass multiple observations as a 2-D array, for instance a DataFrame (see the docs).

Since you are using the formula API, your input needs to be a pd.DataFrame so that the column names referenced in the formula are available. In your case, you could use something like result.predict(pd.DataFrame({'mean_area': [1, 2, 3]})).

By default, when no new data is passed, statsmodels' .predict() uses the observations the model was fitted on.
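To make this concrete, here is a minimal sketch of predicting for new values with a formula-based logit model. The `breast` DataFrame below is synthetic stand-in data (an assumption, not the real breast cancer dataset), and the coefficients used to generate it are arbitrary. The last lines cross-check `.predict()` against the manual logistic formula mentioned in the question.

```python
import numpy as np
import pandas as pd
from statsmodels.formula.api import logit

# Synthetic stand-in for the breast cancer data (assumption, not the real dataset)
rng = np.random.default_rng(0)
mean_area = rng.uniform(100, 2500, size=300)
true_prob = 1 / (1 + np.exp(-(3.0 - 0.005 * mean_area)))
breast = pd.DataFrame({
    "mean_area": mean_area,
    "target": rng.binomial(1, true_prob),
})

result = logit("target ~ mean_area", breast).fit(disp=0)

# New observations must use the same column name as the formula
new_data = pd.DataFrame({"mean_area": [30, 500, 1500]})
predicted = result.predict(new_data)  # predicted probabilities that target == 1
print(predicted)

# Cross-check against the manual logistic formula using the fitted coefficients
b0 = result.params["Intercept"]
b1 = result.params["mean_area"]
manual = 1 / (1 + np.exp(-(b0 + b1 * new_data["mean_area"])))
```

For a single value such as 30, a one-row DataFrame works the same way: `result.predict(pd.DataFrame({'mean_area': [30]}))`.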

Stefan answered Oct 13 '22 04:10


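import statsmodels.formula.api as smf

# df is assumed to be a pandas DataFrame with columns 'x' and 'y'
model = smf.ols('y ~ x', data=df).fit()

# Predict for a list of observations; the list can hold one or many values
prediction = model.get_prediction(exog=dict(x=[5, 10, 25]))
# Table of predicted means with 95% confidence and prediction intervals
prediction.summary_frame(alpha=0.05)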
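A self-contained variant of the snippet above, as a rough sketch: the DataFrame `df` here is synthetic (an assumption, generated from an arbitrary line plus noise) so the example runs on its own. It also shows which columns of the summary frame hold the interval for the mean and which hold the wider interval for a new observation.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Synthetic data (assumption): y is roughly 2 + 0.5 * x plus Gaussian noise
rng = np.random.default_rng(1)
x = np.linspace(0, 30, 50)
df = pd.DataFrame({"x": x, "y": 2.0 + 0.5 * x + rng.normal(0, 1, size=50)})

model = smf.ols("y ~ x", data=df).fit()

# New observations, named after the formula's predictor
new_x = pd.DataFrame({"x": [5, 10, 25]})
frame = model.get_prediction(new_x).summary_frame(alpha=0.05)

# mean_ci_*: confidence interval for the fitted mean
# obs_ci_*:  prediction interval for a single new observation
print(frame[["mean", "mean_ci_lower", "mean_ci_upper",
             "obs_ci_lower", "obs_ci_upper"]])
```

The `mean` column should match what `model.predict(new_x)` returns; `get_prediction` is mainly useful when the intervals are needed as well.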
Shu Zhang answered Oct 13 '22 03:10