 

Error when using predict() on a randomForest object trained with caret's train() using formula

Using R 3.2.0 with caret 6.0-41 and randomForest 4.6-10 on a 64-bit Linux machine.

When trying to use the predict() method on a randomForest object trained with the train() function from the caret package using a formula, the function returns an error. When training via randomForest() and/or using x= and y= rather than a formula, it all runs smoothly.

Here is a reproducible example:

library(randomForest)
library(caret)

data(imports85)
imp85     <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85     <- imp85[complete.cases(imp85), ]
imp85[]   <- lapply(imp85, function(x) if (is.factor(x)) x[,drop=TRUE] else x) ## Drop empty levels for factors.

modRf1  <- randomForest(numOfDoors~., data=imp85)
caretRf <- train( numOfDoors~., data=imp85, method = "rf" )
modRf2  <- caretRf$finalModel
modRf3  <- randomForest(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"])
caretRf <- train(x=imp85[,c("stroke", "price", "fuelType")], y=imp85[, "numOfDoors"], method = "rf")
modRf4  <- caretRf$finalModel

p1      <- predict(modRf1, newdata=imp85)
p2      <- predict(modRf2, newdata=imp85)
p3      <- predict(modRf3, newdata=imp85)
p4      <- predict(modRf4, newdata=imp85)

Of the last four lines, only the second one, p2 <- predict(modRf2, newdata=imp85), fails, with the following error:

Error in predict.randomForest(modRf2, newdata = imp85) : 
variables in the training data missing in newdata

The error seems to arise because the predict.randomForest method uses rownames(object$importance) to determine the names of the variables used to train the random forest object. Looking at

rownames(modRf1$importance)
rownames(modRf2$importance)
rownames(modRf3$importance)
rownames(modRf4$importance)

We see:

[1] "stroke"   "price"    "fuelType"
[1] "stroke"   "price"    "fuelTypegas"
[1] "stroke"   "price"    "fuelType"
[1] "stroke"   "price"    "fuelType"

So somehow, using the caret train() function with a formula changes the names of the (factor) variables in the importance field of the randomForest object.

Is this really an inconsistency between the formula and non-formula versions of the caret train() function? Or am I missing something?

Adrien Combaz asked May 07 '15 09:05



1 Answer

First, you should almost never use the $finalModel object for prediction. Use predict.train instead. This question is a good example of why.
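As a minimal sketch of that advice, the following reuses the question's data and calls predict() on the train object itself rather than on $finalModel. It assumes the caret and randomForest packages are installed; droplevels() and set.seed() are additions for brevity and reproducibility and are not part of the original code.

```r
library(randomForest)
library(caret)

# Same data preparation as in the question; droplevels() is used here as a
# shorthand for dropping empty factor levels (an assumption of this sketch).
data(imports85)
imp85 <- imports85[, c("stroke", "price", "fuelType", "numOfDoors")]
imp85 <- droplevels(imp85[complete.cases(imp85), ])

set.seed(1)  # resampling in train() is random; fixed only for repeatability
caretRf <- train(numOfDoors ~ ., data = imp85, method = "rf")

# Calling predict() on the train object dispatches to predict.train, which
# rebuilds the dummy variables from the stored formula before predicting.
p_train <- predict(caretRf, newdata = imp85)
head(p_train)
```

Here newdata can be the original data frame with fuelType still a factor, because the conversion to dummy variables is handled by caret rather than by predict.randomForest.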

There is some inconsistency between how some functions (including randomForest and train) handle dummy variables. Most functions in R that use the formula method will convert factor predictors to dummy variables because their models require numerical representations of the data. The exceptions to this are tree- and rule-based models (that can split on categorical predictors), naive Bayes, and a few others.

So randomForest will not create dummy variables when you use randomForest(y ~ ., data = dat), but train (and most other functions) will when called as train(y ~ ., data = dat).

The error occurs because fuelType is a factor. The dummy variables created by train don't have the same names as the original columns (fuelType becomes fuelTypegas), so predict.randomForest can't find them in newdata.
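The renaming can be illustrated with base R alone. The sketch below uses model.matrix() on a small made-up data frame (the toy values are hypothetical, not taken from imports85) to show roughly what a formula interface does to a factor predictor; it is meant as an illustration of the dummy-variable expansion, not as a description of caret's exact internals.

```r
# Toy data standing in for imp85; the values are invented for illustration.
toy <- data.frame(
  stroke     = c(2.68, 3.47, 3.40, 3.40),
  price      = c(13495, 16500, 13950, 17450),
  fuelType   = factor(c("gas", "diesel", "gas", "gas")),
  numOfDoors = factor(c("two", "two", "four", "four"))
)

# model.matrix() expands each factor predictor into indicator columns,
# named by pasting the variable name and the level together.
mm <- model.matrix(numOfDoors ~ ., data = toy)

# Drop the intercept column to leave only the predictor names.
dummy_names <- setdiff(colnames(mm), "(Intercept)")
print(dummy_names)
```

The resulting names contain fuelTypegas rather than fuelType, which is the name that predict.randomForest then looks for in newdata.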

Using the non-formula method with train will pass the factor predictors to randomForest and everything will work.

TL;DR

Use the non-formula method with train if you want the same levels, or use predict.train.

topepo answered Nov 12 '22 17:11
