
H2O deeplearning with class imbalance

I am using H2O's deeplearning (a feed-forward deep neural network) for a binary classification. My classes are highly imbalanced and I want to use parameters such as

balance_classes, class_sampling_factors

Can anybody give me a reproducible example of how to initialize these parameters to handle class imbalance problems?

Asked Dec 24 '22 by Panchacookie

1 Answer

First, here is the full, reproducible example:

library(h2o)
h2o.init()

data(iris)  # iris is built into R, so this line is optional
iris <- iris[1:120,] #Remove 60% of virginica
summary(iris$Species) #50/50/20

d <- as.h2o(iris)
splits = h2o.splitFrame(d,0.8,c("train","test"), seed=77)
train = splits[[1]]
test = splits[[2]]
summary(train$Species)  #41/41/14
summary(test$Species)  #9/9/6

m1 = h2o.randomForest(1:4, 5, train, model_id ="RF_defaults", seed=1)
h2o.confusionMatrix(m1)

m2 = h2o.randomForest(1:4, 5, train, model_id ="RF_balanced", seed=1,
  balance_classes = TRUE)
h2o.confusionMatrix(m2)

m3 = h2o.randomForest(1:4, 5, train, model_id ="RF_balanced_2.5", seed=1,
  balance_classes = TRUE,
  class_sampling_factors = c(1, 1, 2.5)
  )
h2o.confusionMatrix(m3)

The first lines initialize H2O; then I deliberately modify the iris data set, throwing away 60% of one of the three classes, to create an imbalance.

The next few lines load that data into H2O and create an 80%/20% train/test split. The seed was chosen deliberately, so that in the training data virginica is 14.58% of the data, compared to 16.67% in the original data and 25% in the test data.
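
If you want to double-check those percentages, a quick sketch (pulling the H2O frames back into R with as.data.frame) is:

prop.table(table(iris$Species))                  # virginica 20/120 ~ 16.67%
prop.table(table(as.data.frame(train)$Species))  # virginica 14/96  ~ 14.58%
prop.table(table(as.data.frame(test)$Species))   # virginica  6/24   = 25%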

I then train three random forest models. (The balance_classes and class_sampling_factors parameters work the same way in h2o.deeplearning, so everything here carries over to the deep learning case.) m1 is all defaults, and its confusion matrix looks like this:

           setosa versicolor virginica  Error     Rate
setosa         41          0         0 0.0000 = 0 / 41
versicolor      0         39         2 0.0488 = 2 / 41
virginica       0          1        13 0.0714 = 1 / 14
Totals         41         40        15 0.0312 = 3 / 96

Nothing to see here: it uses the data it finds.

Now here is the same output for m2, which switches on balance_classes. You can see it has over-sampled the virginica class to get the classes as balanced as possible. (The right-most column says 41, 41, 40 instead of 41, 41, 14 as in the previous output.)

           setosa versicolor virginica  Error      Rate
setosa         41          0         0 0.0000 =  0 / 41
versicolor      0         41         0 0.0000 =  0 / 41
virginica       0          2        38 0.0500 =  2 / 40
Totals         41         43        38 0.0164 = 2 / 122

In m3 we still switch on balance_classes, but we also tell it the truth of the situation: the actual data is 16.67% virginica, not the 14.58% it sees in the training data. The confusion matrix for m3 shows that it therefore turned the 14 virginica samples into 37 samples instead of 40.

           setosa versicolor virginica  Error      Rate
setosa         41          0         0 0.0000 =  0 / 41
versicolor      0         41         0 0.0000 =  0 / 41
virginica       0          2        35 0.0541 =  2 / 37
Totals         41         43        35 0.0168 = 2 / 119
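
As a sanity check (assuming the usual H2O R model object, which exposes the non-default settings it was trained with via its @parameters slot), you can read the sampling configuration back from m3:

m3@parameters$balance_classes         # TRUE
m3@parameters$class_sampling_factors  # 1.0 1.0 2.5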

How did I know to write c(1, 1, 2.5), and not c(2.5, 1, 1) or c(1, 2.5, 1)? The docs say it must be in "lexicographic order". You can find out what that order is with:

h2o.levels(train$Species)

which tells me:

[1] "setosa"     "versicolor" "virginica"

The opinion bit: balance_classes is good to switch on, but class_sampling_factors should only be used when you have a really good reason to believe that your training data is not representative.
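
The same two parameters exist on h2o.deeplearning, so the approach carries straight over to the question's setting. A minimal sketch (the hidden layer sizes and epochs are arbitrary illustration values, not tuned):

m4 <- h2o.deeplearning(1:4, 5, train,   # same x, y, training_frame as above
  model_id = "DL_balanced",
  hidden = c(32, 32), epochs = 20,      # illustration values only
  seed = 1,
  balance_classes = TRUE,
  class_sampling_factors = c(1, 1, 2.5))
h2o.confusionMatrix(m4)

Note that for an exactly reproducible deep learning run you would also need reproducible = TRUE, which forces single-threaded training.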

NOTE: Code and explanation adapted from my upcoming book, Practical Machine Learning with H2O.

Answered Jan 07 '23 by Darren Cook