Result of rpart is only a root node, but the data shows information gain

I have a dataset with an event rate of less than 3% (i.e. there are about 700 records with class 1 and 27000 records with class 0).

ID          V1  V2      V3  V5      V6  Target
SDataID3    161 ONE     1   FOUR    0   0
SDataID4    11  TWO     2   THREE   2   1
SDataID5    32  TWO     2   FOUR    2   0
SDataID7    13  ONE     1   THREE   2   0
SDataID8    194 TWO     2   FOUR    0   0
SDataID10   63  THREE   3   FOUR    0   1
SDataID11   89  ONE     1   FOUR    0   0
SDataID13   78  TWO     2   FOUR    0   0
SDataID14   87  TWO     2   THREE   1   0
SDataID15   81  ONE     1   THREE   0   0
SDataID16   63  ONE     3   FOUR    0   0
SDataID17   198 ONE     3   THREE   0   0
SDataID18   9   TWO     3   THREE   0   0
SDataID19   196 ONE     2   THREE   2   0
SDataID20   189 TWO     2   ONE     1   0
SDataID21   116 THREE   3   TWO     0   0
SDataID24   104 ONE     1   FOUR    0   0
SDataID25   5   ONE     2   ONE     3   0
SDataID28   173 TWO     3   FOUR    0   0
SDataID29   5   ONE     3   ONE     3   0
SDataID31   87  ONE     3   FOUR    3   0
SDataID32   5   ONE     2   THREE   1   0
SDataID34   45  ONE     1   FOUR    0   0
SDataID35   19  TWO     2   THREE   0   0
SDataID37   133 TWO     2   FOUR    0   0
SDataID38   8   ONE     1   THREE   0   0
SDataID39   42  ONE     1   THREE   0   0
SDataID43   45  ONE     1   THREE   1   0
SDataID44   45  ONE     1   FOUR    0   0
SDataID45   176 ONE     1   FOUR    0   0
SDataID46   63  ONE     1   THREE   3   0

I am trying to find splits using a decision tree, but the resulting tree has only the root node.

> library(rpart)
> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class")
> printcp(tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)), method = "class")

Variables actually used in tree construction:
character(0)

Root node error: 749/18239 = 0.041066

n= 18239 

  CP nsplit rel error xerror xstd
1  0      0         1      0    0

After reading most of the related answers on Stack Overflow, I loosened the control parameters (minsplit, minbucket and cp), which gave me the desired decision tree.

> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class" ,control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))
> printcp(tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)), 
    method = "class", control = rpart.control(minsplit = 1, minbucket = 2, 
        cp = 2e-05))

Variables actually used in tree construction:
[1] V5         V2                     V1          
[4] V3         V6

Root node error: 749/18239 = 0.041066

n= 18239 

          CP nsplit rel error xerror     xstd
1 0.00024275      0   1.00000 1.0000 0.035781
2 0.00019073     20   0.99466 1.0267 0.036235
3 0.00016689     34   0.99199 1.0307 0.036302
4 0.00014835     54   0.98798 1.0334 0.036347
5 0.00002000     63   0.98665 1.0427 0.036504

When I pruned the tree, however, the result was again a tree with a single node.

> pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])
> printcp(pruned.tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)), 
    method = "class", control = rpart.control(minsplit = 1, minbucket = 2, 
        cp = 2e-05))

Variables actually used in tree construction:
character(0)

Root node error: 749/18239 = 0.041066

n= 18239 

          CP nsplit rel error xerror     xstd
1 0.00024275      0         1      1 0.035781

The tree should not consist of only the root node, because mathematically we get information gain on a given node (see the example below). Am I making a mistake when pruning, or does rpart have an issue handling a dataset with a low event rate?

NODE    p       1-p     Entropy         Weights         Ent*Weight      # Obs
Node 1  0.032   0.968   0.204324671     0.351398601     0.071799404     10653
Node 2  0.05    0.95    0.286396957     0.648601399     0.185757467     19663

Sum(Ent*wght)       0.257556871 
Information gain    0.742443129 
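The 0.742 figure above is 1 minus the weighted child entropy, i.e. it takes 1 bit as the parent entropy, which only holds for a perfectly balanced node. Information gain is usually defined as the parent node's entropy minus the weighted entropy of its children. A minimal base-R sketch recomputing it from the table, assuming the parent node is simply the union of Node 1 and Node 2 (the helper name `entropy` is hypothetical):

```r
# Binary entropy (in bits) of a node whose event rate is p, for 0 < p < 1
entropy <- function(p) -(p * log2(p) + (1 - p) * log2(1 - p))

# Values from the table above
p_child <- c(0.032, 0.05)    # event rate in Node 1 and Node 2
n_child <- c(10653, 19663)   # observations in Node 1 and Node 2
w_child <- n_child / sum(n_child)

# Weighted child entropy: the table's Sum(Ent*wght)
child_entropy <- sum(w_child * entropy(p_child))

# Assumption: the parent node is the union of Node 1 and Node 2
p_parent       <- sum(p_child * n_child) / sum(n_child)
parent_entropy <- entropy(p_parent)

# Information gain = parent entropy - weighted child entropy
info_gain <- parent_entropy - child_entropy
round(c(parent = parent_entropy, children = child_entropy, gain = info_gain), 4)
```

Under that assumption the parent's event rate is about 4.4%, so its entropy is already only about 0.259 bits and the gain is roughly 0.0013 bits rather than 0.742.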
asked Oct 31 '17 at 06:10 by Rachit Jain


1 Answer

The data you provided does not reflect the ratio of the two target classes, so I've tweaked the data to better reflect that (see Data section):

> prop.table(table(train$Target))

         0          1 
0.96707581 0.03292419 

> 700/27700
[1] 0.02527076

The ratios are now relatively close...

library(rpart)
tree <- rpart(Target ~ ., data=train, method="class")
printcp(tree)

Results in:

Classification tree:
rpart(formula = Target ~ ., data = train, method = "class")

Variables actually used in tree construction:
character(0)

Root node error: 912/27700 = 0.032924

n= 27700 

  CP nsplit rel error xerror xstd
1  0      0         1      0    0

The reason you are seeing only the root node in your first model is probably that your target classes are extremely imbalanced, so your independent variables could not provide enough information to grow the tree. My sample data has a 3.3% event rate, but yours has only about 2.5%!

As you have mentioned, there is a way to force rpart to grow the tree: override the default complexity parameter (cp). The complexity measure combines the size of the tree with how well the tree separates the target classes. From ?rpart.control, "Any split that does not decrease the overall lack of fit by a factor of cp is not attempted". This means that, at this point, your model has no split beyond the root node that decreases the lack of fit enough for rpart to consider it. We can relax this threshold of what counts as "enough" by setting either a low or a negative cp (a negative cp essentially forces the tree to grow to its full size).
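One way to see why a split like the one in the question is not kept: with rpart's default 0/1 loss, the lack of fit of a classification tree is its misclassification error (the "Root node error" line reports exactly this). A minimal base-R sketch using the p and # Obs values from the question's table, assuming the parent node is the union of the two child nodes (the helper name `node_error` is hypothetical):

```r
# Values from the question's information gain table
p_child <- c(0.032, 0.05)    # event rate in each child node
n_child <- c(10653, 19663)   # number of observations in each child node

# With 0/1 loss each node predicts its majority class, so its error rate
# is the share of the minority class
node_error <- function(p) pmin(p, 1 - p)

# Assumption: the parent is the union of the two children
p_parent     <- sum(p_child * n_child) / sum(n_child)
parent_error <- node_error(p_parent)
child_error  <- sum(n_child * node_error(p_child)) / sum(n_child)

# Both children still predict class 0, so the error does not drop at all:
# the split's improvement relative to the root error is zero
(parent_error - child_error) / parent_error
```

Because both children remain majority class 0, every class-1 record is misclassified before and after the split, so such a split cannot clear any positive cp, however much entropy it removes.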

tree <- rpart(Target ~ ., data=train, method="class" ,parms = list(split = 'information'), 
              control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))
printcp(tree)

Results in:

Classification tree:
rpart(formula = Target ~ ., data = train, method = "class", parms = list(split = "information"), 
    control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
[1] ID V1 V2 V3 V5 V6

Root node error: 912/27700 = 0.032924

n= 27700 

           CP nsplit rel error xerror     xstd
1  4.1118e-04      0   1.00000 1.0000 0.032564
2  3.6550e-04     30   0.98355 1.0285 0.033009
3  3.2489e-04     45   0.97807 1.0702 0.033647
4  3.1328e-04    106   0.95504 1.0877 0.033911
5  2.7412e-04    116   0.95175 1.1031 0.034141
6  2.5304e-04    132   0.94737 1.1217 0.034417
7  2.1930e-04    149   0.94298 1.1458 0.034771
8  1.9936e-04    159   0.94079 1.1502 0.034835
9  1.8275e-04    181   0.93640 1.1645 0.035041
10 1.6447e-04    193   0.93421 1.1864 0.035356
11 1.5664e-04    233   0.92654 1.1853 0.035341
12 1.3706e-04    320   0.91228 1.2083 0.035668
13 1.2183e-04    344   0.90899 1.2127 0.035730
14 9.9681e-05    353   0.90789 1.2237 0.035885
15 2.0000e-05    364   0.90680 1.2259 0.035915

As you can see, the tree has grown much larger, keeping every split that decreases the overall lack of fit by at least cp. Two things to note:

  1. At nsplit == 0, CP is already as low as 0.0004, whereas the default cp in rpart is 0.01.
  2. Starting from nsplit == 0, the cross-validation error (xerror) increases as you increase the number of splits.

Both of these indicate that any split beyond the root node overfits the data: adding more independent variables to your model does not add enough information (the reduction in lack of fit is far below the default cp) to reduce the cross-validation error. That said, the root-node model is the best model in this case, which explains why your initial model has only the root node.
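Both points can be checked directly on the CP table printed above. A minimal sketch: `cptable` here is a plain matrix typed in from the printcp() output (standing in for `tree$cptable`, which also has "rel error" and "xstd" columns), and the selection line mirrors the `which.min` call used for pruning:

```r
# CP, nsplit and xerror columns copied from the printcp() output above
cptable <- cbind(
  CP     = c(4.1118e-04, 3.6550e-04, 3.2489e-04, 3.1328e-04, 2.7412e-04,
             2.5304e-04, 2.1930e-04, 1.9936e-04, 1.8275e-04, 1.6447e-04,
             1.5664e-04, 1.3706e-04, 1.2183e-04, 9.9681e-05, 2.0000e-05),
  nsplit = c(0, 30, 45, 106, 116, 132, 149, 159, 181, 193,
             233, 320, 344, 353, 364),
  xerror = c(1.0000, 1.0285, 1.0702, 1.0877, 1.1031, 1.1217, 1.1458, 1.1502,
             1.1645, 1.1864, 1.1853, 1.2083, 1.2127, 1.2237, 1.2259)
)

# Point 1: even the first CP value is far below rpart's default cp of 0.01
cptable[1, "CP"] < 0.01

# Point 2: the row with the smallest cross-validated error is the root-only tree
best <- which.min(cptable[, "xerror"])
cptable[best, "nsplit"]

# Every larger tree has a smaller CP than the selected one,
# so pruning at that CP removes all of its splits
all(cptable[-best, "CP"] < cptable[best, "CP"])
```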

pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])
printcp(pruned.tree)

Results in:

Classification tree:
rpart(formula = Target ~ ., data = train, method = "class", parms = list(split = "information"), 
    control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
character(0)

Root node error: 912/27700 = 0.032924

n= 27700 

          CP nsplit rel error xerror     xstd
1 0.00041118      0         1      1 0.032564

As for pruning, it is now clear why your pruned tree is the root-node tree: every tree with more than 0 splits has a higher cross-validation error. Taking the tree with the minimum xerror leaves you with the root-node tree, as expected.

Information gain tells you how much "information" each split adds. So technically, every split has some information gain, since you are adding more variables to your model (information gain is always non-negative). What you should ask is whether that additional gain (or lack of gain) reduces the error enough to warrant a more complex model; this is the tradeoff between bias and variance.
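To make the non-negativity point concrete, here is a base-R sketch on simulated data (the helper names `entropy` and `info_gain`, the seed and the 2.5% rate are illustrative assumptions): labels are split at random by a predictor that is pure noise, and the training-set gain still comes out non-negative and almost always positive, although tiny. A positive gain on the training data therefore does not, by itself, justify a split.

```r
set.seed(42)  # arbitrary seed, for reproducibility only

# Entropy (in bits) of a vector of 0/1 labels
entropy <- function(y) {
  p <- mean(y)
  if (p == 0 || p == 1) return(0)  # a pure node has zero entropy
  -(p * log2(p) + (1 - p) * log2(1 - p))
}

# Information gain of splitting labels y by the logical vector `left`
info_gain <- function(y, left) {
  w <- mean(left)
  entropy(y) - (w * entropy(y[left]) + (1 - w) * entropy(y[!left]))
}

n <- 27700
y <- rbinom(n, 1, 0.025)  # roughly a 2.5% event rate, as in the question

# 1000 random splits produced by a predictor that is pure noise
gains <- replicate(1000, info_gain(y, runif(n) < 0.5))

min(gains) >= -1e-12  # never negative (up to floating-point error)
mean(gains > 0)       # share of splits with a strictly positive gain
max(gains)            # yet every gain is tiny
```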

In this case, it doesn't really make sense to reduce cp and then prune the resulting tree: by setting a low cp, you are telling rpart to make splits even if they overfit, while pruning "cuts" all the nodes that overfit.

Data:

Note that I sample each column independently instead of sampling the row indices. This is because the data you provided is probably not a random sample of your original dataset (it is likely biased), so I am randomly creating new observations from combinations of your existing values, which will hopefully reduce that bias.

init_train = structure(list(ID = structure(c(16L, 24L, 29L, 30L, 31L, 1L, 
2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 
17L, 18L, 19L, 20L, 21L, 22L, 23L, 25L, 26L, 27L, 28L), .Label = c("SDataID10", 
"SDataID11", "SDataID13", "SDataID14", "SDataID15", "SDataID16", 
"SDataID17", "SDataID18", "SDataID19", "SDataID20", "SDataID21", 
"SDataID24", "SDataID25", "SDataID28", "SDataID29", "SDataID3", 
"SDataID31", "SDataID32", "SDataID34", "SDataID35", "SDataID37", 
"SDataID38", "SDataID39", "SDataID4", "SDataID43", "SDataID44", 
"SDataID45", "SDataID46", "SDataID5", "SDataID7", "SDataID8"), class = "factor"), 
    V1 = c(161L, 11L, 32L, 13L, 194L, 63L, 89L, 78L, 87L, 81L, 
    63L, 198L, 9L, 196L, 189L, 116L, 104L, 5L, 173L, 5L, 87L, 
    5L, 45L, 19L, 133L, 8L, 42L, 45L, 45L, 176L, 63L), V2 = structure(c(1L, 
    3L, 3L, 1L, 3L, 2L, 1L, 3L, 3L, 1L, 1L, 1L, 3L, 1L, 3L, 2L, 
    1L, 1L, 3L, 1L, 1L, 1L, 1L, 3L, 3L, 1L, 1L, 1L, 1L, 1L, 1L
    ), .Label = c("ONE", "THREE", "TWO"), class = "factor"), 
    V3 = c(1L, 2L, 2L, 1L, 2L, 3L, 1L, 2L, 2L, 1L, 3L, 3L, 3L, 
    2L, 2L, 3L, 1L, 2L, 3L, 3L, 3L, 2L, 1L, 2L, 2L, 1L, 1L, 1L, 
    1L, 1L, 1L), V5 = structure(c(1L, 3L, 1L, 3L, 1L, 1L, 1L, 
    1L, 3L, 3L, 1L, 3L, 3L, 3L, 2L, 4L, 1L, 2L, 1L, 2L, 1L, 3L, 
    1L, 3L, 1L, 3L, 3L, 3L, 1L, 1L, 3L), .Label = c("FOUR", "ONE", 
    "THREE", "TWO"), class = "factor"), V6 = c(0L, 2L, 2L, 2L, 
    0L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 0L, 2L, 1L, 0L, 0L, 3L, 0L, 
    3L, 3L, 1L, 0L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 3L), Target = c(0L, 
    1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 
    0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L
    )), .Names = c("ID", "V1", "V2", "V3", "V5", "V6", "Target"
), class = "data.frame", row.names = c(NA, -31L))

set.seed(1000)
# Resample each column independently (with replacement) to build 27,700 rows
train = as.data.frame(lapply(init_train, function(x) sample(x, 27700, replace = TRUE)))
answered Nov 15 '22 at 19:11 by acylam