Usage of caret with gbm method for multiclass classification




I am solving a multiclass classification problem and trying to use Generalized Boosted Models (the gbm package in R). The issue I face: caret's train function with method="gbm" does not seem to handle multiclass data properly. A simple example is presented below.

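library(caret)  # method="gbm" also requires the gbm package to be installed
fitControl <- trainControl(method="repeatedcv", number=5, repeats=1, verboseIter=TRUE)
gbmFit <- train(Species ~ ., data=iris, method="gbm",
                trControl=fitControl, verbose=FALSE)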

The output is

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero

Yet if I use gbm directly, without the caret wrapper, I get good results.

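library(caret)
library(gbm)
# 'inTrain' avoids confusion with caret's train() function
inTrain <- createDataPartition(iris$Species, p=0.7, list=FALSE)
train.iris <- iris[inTrain,]
valid.iris <- iris[-inTrain,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, distribution="multinomial",
                    n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
gbm.pred <- drop(gbm.pred)  # removes the trailing n.trees dimension if predict() returns a 3-D array
gbm.pred <- factor(colnames(gbm.pred)[max.col(gbm.pred)], levels=levels(valid.iris$Species)) ##!
confusionMatrix(gbm.pred, valid.iris$Species)$overall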

FYI, the line marked ##! converts the matrix of class probabilities returned by predict.gbm into a factor of the most probable classes. The output is

      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull AccuracyPValue  McnemarPValue 
  9.111111e-01   8.666667e-01   7.877883e-01   9.752470e-01   3.333333e-01   8.467252e-16            NaN 
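To make the ##! step concrete, here is a toy example in which a small hand-made probability matrix (made-up numbers, not real gbm output) stands in for what predict.gbm returns. max.col() gives, for each row, the column index of the largest probability, and indexing colnames() with it yields the predicted class label. Fixing levels= keeps every class as a factor level even when one is never predicted; recent caret versions expect the prediction and reference factors passed to confusionMatrix() to have identical levels.

```r
# Toy class-probability matrix (made-up numbers): one row per observation,
# one column per class, mimicking predict.gbm output for iris.
prob <- matrix(c(0.80, 0.15, 0.05,
                 0.10, 0.70, 0.20,
                 0.60, 0.30, 0.10),
               ncol = 3, byrow = TRUE,
               dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# Column index of the largest probability in each row.
# ties.method = "first" is deterministic; the default "random" breaks ties at random.
idx <- max.col(prob, ties.method = "first")
print(idx)  # 1 2 1

# Map indices to class names, keeping all three classes as levels.
pred <- factor(colnames(prob)[idx], levels = colnames(prob))
print(pred)

# Without levels=, only the predicted classes survive as levels.
nlevels(factor(colnames(prob)[idx]))  # 2: "virginica" is lost
```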

Any suggestions on how to make caret work properly with gbm on multiclass data?

sessionInfo() output:
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=C                 LC_NAME=C                 

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] e1071_1.6-1      class_7.3-5      gbm_2.0-8        survival_2.36-14 caret_5.15-61    reshape2_1.2.2   plyr_1.8        
 [8] lattice_0.20-13  foreach_1.4.0    cluster_1.14.3   compare_0.2-3   

loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3     iterators_1.0.6 stringr_0.6.2   tools_2.15.3   
2 Answers

This is an issue that I'm working on right now.

It would help if you posted the results of sessionInfo().

Also, installing the latest version of gbm from https://code.google.com/p/gradientboostedmodels/ might solve the problem.


Update: caret can do multiclass classification.

Make sure the class labels are alphanumeric and start with a letter.
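One way to check this, assuming the underlying requirement is that factor levels be syntactically valid R names (the form make.names() produces), is to compare each level with its make.names() version. labels_ok is a hypothetical helper written for illustration; it is not part of caret.

```r
# Hypothetical helper (not part of caret): TRUE when every class label is
# already a syntactically valid R name, i.e. make.names() leaves it unchanged.
labels_ok <- function(y) {
  lev <- levels(as.factor(y))
  all(lev == make.names(lev))
}

labels_ok(iris$Species)        # TRUE: "setosa", "versicolor", "virginica"
labels_ok(factor(c(1, 2, 3)))  # FALSE: make.names("1") is "X1"
```

Note that valid R names may also contain "." and "_", so this check is slightly looser than "alphanumeric only".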

For example, if your data has the labels "1", "2" and "3", change them to "Seg1", "Seg2" and "Seg3"; otherwise caret will fail.
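A minimal sketch of that relabeling in base R, using a made-up outcome vector y; with a real data frame you would apply the same levels<- assignment to the outcome column before calling train().

```r
# Made-up outcome with purely numeric labels.
y <- factor(c(1, 2, 3, 2, 1))

# Prefix each level with "Seg". Assigning to levels() renames the levels
# in place, so every observation keeps its original class.
levels(y) <- paste0("Seg", levels(y))
print(levels(y))  # "Seg1" "Seg2" "Seg3"
print(y)
```

Assigning new names through levels() rather than rebuilding the factor from scratch guarantees the mapping from old to new labels stays one-to-one.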

