
Using cross_val_predict against test data set

I'm confused about how to use cross_val_predict with a test data set.

I created a simple Random Forest model and used cross_val_predict to make predictions:

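import pandas as pd
from sklearn.ensemble import RandomForestClassifier
# sklearn.cross_validation was removed in scikit-learn 0.20; the same tools live in model_selection
from sklearn.model_selection import cross_val_predict, KFold

lr = RandomForestClassifier(random_state=1, class_weight="balanced", n_estimators=25, max_depth=6)
# KFold now takes the number of folds, not the number of rows. The old KFold(n, random_state=1)
# meant 3 unshuffled folds (random_state had no effect without shuffle=True).
kf = KFold(n_splits=3)
# Out-of-fold predictions: one per row of train_df, each from a model that did not see that row
predictions = cross_val_predict(lr, train_df[features_columns], train_df["target"], cv=kf)
predictions = pd.Series(predictions)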

I'm confused about the next step. How do I use what was learned above to make predictions on the test data set?

Kabard asked Jan 10 '17 02:01


2 Answers

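I don't think you need to call fit yourself before cross_val_score or cross_val_predict. They fit the model on the fly, once per cross-validation split, on a fresh copy (clone) of the estimator you pass in, which is why the documentation never asks you to call fit first. A side effect is that the estimator you passed (lr here) is still unfitted afterwards, so it cannot predict on new data until you fit it.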
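A minimal sketch of this behaviour, assuming a recent scikit-learn (`sklearn.model_selection`) and synthetic data from `make_classification` standing in for the question's `train_df`. It shows that `cross_val_predict` returns one out-of-fold prediction per row while the `lr` object itself stays unfitted.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import KFold, cross_val_predict

# Synthetic stand-in for train_df[features_columns] and train_df["target"] (assumption)
X, y = make_classification(n_samples=200, n_features=8, random_state=1)

lr = RandomForestClassifier(random_state=1, class_weight="balanced",
                            n_estimators=25, max_depth=6)
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# Internally, each split gets a clone of lr that is fit on the other folds
predictions = cross_val_predict(lr, X, y, cv=kf)
print(predictions.shape)  # one out-of-fold prediction per row: (n_samples,)

# The lr object itself was never fitted, so predicting with it raises NotFittedError
try:
    lr.predict(X)
    lr_is_fitted = True
except NotFittedError:
    lr_is_fitted = False
print(lr_is_fitted)
```

To predict on a separate test set you still need an explicit `lr.fit(...)` on the training data.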

Sandeep Kumar answered Oct 22 '22 04:10


As @DmitryPolonskiy commented, the model has to be trained (with the fit method) before it can be used to predict.

# Train the model (a.k.a. `fit` training data to it).
lr.fit(train_df[features_columns], train_df["target"])
# Use the model to make predictions based on testing data.
y_pred = lr.predict(test_df[features_columns])
# Compare the predicted y values to actual y values.
accuracy = (y_pred == test_df["target"]).mean()

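Putting the two steps together, here is a sketch of the usual workflow under stated assumptions: the data is synthetic, and `train_df`/`test_df` are hypothetical frames built with `train_test_split` in place of the question's real ones. `cross_val_predict` is used only on the training set to estimate performance; the model is then fit once on all training rows and used on the test set. Note that the scikit-learn docs caution that scoring pooled `cross_val_predict` output is only a rough estimate; `cross_val_score` gives per-fold scores.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_val_predict, train_test_split

# Hypothetical data standing in for the question's train_df / test_df
X, y = make_classification(n_samples=300, n_features=6, random_state=1)
features_columns = [f"f{i}" for i in range(X.shape[1])]
df = pd.DataFrame(X, columns=features_columns)
df["target"] = y
train_df, test_df = train_test_split(df, test_size=0.25, random_state=1)

lr = RandomForestClassifier(random_state=1, class_weight="balanced",
                            n_estimators=25, max_depth=6)

# Step 1: out-of-fold predictions on the training set, for evaluation only
kf = KFold(n_splits=5, shuffle=True, random_state=1)
oof_pred = cross_val_predict(lr, train_df[features_columns], train_df["target"], cv=kf)
cv_accuracy = accuracy_score(train_df["target"], oof_pred)

# Step 2: fit once on the full training set, then predict the held-out test set
lr.fit(train_df[features_columns], train_df["target"])
y_pred = lr.predict(test_df[features_columns])
test_accuracy = accuracy_score(test_df["target"], y_pred)

print(f"CV accuracy:   {cv_accuracy:.3f}")
print(f"Test accuracy: {test_accuracy:.3f}")
```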
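cross_val_predict is a cross-validation tool: for every training row it returns the prediction of a model that was trained on the other folds, which lets you estimate how well your model generalizes. It is not meant for producing predictions on a separate test set. Take a look at sklearn's cross-validation page.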
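For intuition, here is a rough manual equivalent of what `cross_val_predict` does for `method="predict"`, sketched with numpy arrays and synthetic data (assumptions, not the question's data). Each fold gets a fresh `clone` of the model, is fit on the remaining folds, and fills in predictions for its own rows. Because the splitter and the forest both use a fixed `random_state`, the manual loop should reproduce the library's output.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_predict

X, y = make_classification(n_samples=200, n_features=8, random_state=1)
model = RandomForestClassifier(random_state=1, n_estimators=25, max_depth=6)
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# Manual loop: each row is predicted by a model that never saw it during training
manual_pred = np.empty_like(y)
for train_idx, test_idx in kf.split(X):
    fold_model = clone(model)  # fresh, unfitted copy for this fold
    fold_model.fit(X[train_idx], y[train_idx])
    manual_pred[test_idx] = fold_model.predict(X[test_idx])

# cross_val_predict performs the same bookkeeping internally
auto_pred = cross_val_predict(model, X, y, cv=kf)
print(np.array_equal(manual_pred, auto_pred))
```

The original `model` is only ever cloned, never fitted, which is why a separate `fit` call is needed before predicting on test data.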

jakub answered Oct 22 '22 02:10
