 

Usage of caret with gbm method for multiclass classification

Tags:

r

r-caret

I am solving a multiclass classification problem and trying to use Generalized Boosted Models (gbm package in R). The issue I faced: caret's train function with method="gbm" seems not to work with multiclass data properly. A simple example is presented below.

library(gbm)
library(caret)
data(iris)
fitControl <- trainControl(method="repeatedcv",
                           number=5,
                           repeats=1,
                           verboseIter=TRUE)
set.seed(825)
gbmFit <- train(Species ~ ., data=iris,
                method="gbm",
                trControl=fitControl,
                verbose=FALSE)
gbmFit

The output is

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
...
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero

Yet if I use gbm directly, without the caret wrapper, I get good results.

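set.seed(1365)
inTrain <- createDataPartition(iris$Species, p=0.7, list=FALSE)
train.iris <- iris[inTrain,]
valid.iris <- iris[-inTrain,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, distribution="multinomial",
                    n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
# newer gbm versions return an obs x class x n.trees array; keep the single n.trees slice
if (length(dim(gbm.pred)) == 3) gbm.pred <- gbm.pred[, , 1]
gbm.pred <- factor(colnames(gbm.pred)[max.col(gbm.pred)], levels=levels(valid.iris$Species)) ##!
confusionMatrix(gbm.pred, valid.iris$Species)$overall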

FYI, the line marked ##! converts the matrix of class probabilities returned by predict.gbm into a factor of the most probable classes. The output is

      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull AccuracyPValue  McnemarPValue 
  9.111111e-01   8.666667e-01   7.877883e-01   9.752470e-01   3.333333e-01   8.467252e-16            NaN 
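The ##! conversion is easier to follow on a tiny example. The sketch below uses a made-up 3 x 3 probability matrix (the values and the names probs, best and pred are illustrative only, not real gbm output): max.col() returns the column index of the largest probability in each row, and colnames() maps that index back to a class name. Passing levels= keeps every class in the factor even if it is never predicted, which matters because confusionMatrix() compares the predicted factor against the reference levels.

```r
# Toy class-probability matrix: one row per observation, one column per class
# (made-up values, for illustration only).
probs <- matrix(c(0.80, 0.15, 0.05,
                  0.10, 0.60, 0.30,
                  0.20, 0.30, 0.50),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# For each row, the column index of the largest probability;
# ties.method = "first" makes the result deterministic if values tie.
best <- max.col(probs, ties.method = "first")

# Map column indices back to class names, keeping the full set of levels.
pred <- factor(colnames(probs)[best], levels = colnames(probs))
print(pred)
```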

Any suggestions on how to make caret work properly with gbm on multiclass data?

UPD:

sessionInfo()
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=C                 LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] e1071_1.6-1      class_7.3-5      gbm_2.0-8        survival_2.36-14 caret_5.15-61    reshape2_1.2.2   plyr_1.8        
 [8] lattice_0.20-13  foreach_1.4.0    cluster_1.14.3   compare_0.2-3   

loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3     iterators_1.0.6 stringr_0.6.2   tools_2.15.3   
maruan asked Mar 23 '13 10:03





2 Answers

This is an issue that I'm working on right now.

It would help if you posted the results of sessionInfo().

Also, installing the latest gbm from https://code.google.com/p/gradientboostedmodels/ might solve the problem.

Max

topepo answered Oct 13 '22 07:10



Update: caret can do multiclass classification.

Make sure the class labels are alphanumeric and start with a letter.

For example, if your data has labels "1", "2", "3", change them to "Seg1", "Seg2" and "Seg3"; otherwise caret will fail.
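A minimal base-R sketch of that relabelling, assuming a hypothetical outcome factor y with numeric labels. One likely reason for the failure (an assumption, not stated in the answer) is that caret uses the class levels as column names for class probabilities, so they must be syntactically valid R names. make.names() shows what R would turn each level into; any level it changes is not a valid name.

```r
# Hypothetical outcome with purely numeric class labels.
y <- factor(c(1, 2, 3, 1, 3, 2))

# TRUE only if every level is already a syntactically valid R name;
# here it is FALSE because make.names("1") gives "X1".
valid_before <- all(make.names(levels(y)) == levels(y))

# Prefix each level with a letter: "1" -> "Seg1", "2" -> "Seg2", ...
levels(y) <- paste0("Seg", levels(y))
print(levels(y))
```

Relabelling the factor levels (rather than the raw values) keeps every observation mapped to the same class; only the names change.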

Apurva Dubey answered Oct 13 '22 08:10
