R: rpart tree grows using two explanatory variables, but not after removing the less important variable

Data: I'm using the "attrition" dataset from the rsample package.

Question: Using the attrition dataset and rpart library, I can grow a tree using the formula "Attrition ~ OverTime + JobRole", where OverTime is chosen as the first split. But when I try to grow the tree without the JobRole variable (i.e. "Attrition ~ OverTime"), the tree does not split and returns only the root node. This occurs using both the rpart function, as well as caret's train function with method = "rpart".

I'm confused by this, since I thought that the CART algorithm implemented in rpart chose the best variable to split in an iterative greedy fashion, and did not "look ahead" to see how the presence of other variables impacts its selection of the best split. If the algorithm selected OverTime as a worthwhile first split in the case with two explanatory variables, why doesn't it select OverTime as a worthwhile first split after removing the JobRole variable?
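
A quick sanity check (computing the Gini gain by hand, assuming the Gini index, rpart's default impurity measure for classification) suggests the OverTime split does reduce impurity on its own, which is why the behaviour seems contradictory to me:

```r
# Sanity check: Gini gain of the OverTime split, computed by hand
# (assuming Gini impurity, rpart's default for method = "class")
suppressPackageStartupMessages(library(rsample))
data(attrition)

tab <- table(attrition$OverTime, attrition$Attrition)
gini <- function(p) 1 - sum(p^2)

n <- sum(tab)
parent_impurity <- gini(colSums(tab) / n)
child_impurity  <- sum(rowSums(tab) / n * apply(prop.table(tab, 1), 1, gini))
parent_impurity - child_impurity  # positive, so the split does reduce impurity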

I'm using R version 3.4.2 and RStudio Version 1.1.442 with Windows 7.

Research: I found similar Stack Overflow questions here and here, but neither had a complete answer.

As best I can tell, the rpart docs seem to say on page 5 that the rpart algorithm does not use "look ahead" rules:

One way around both of these problems is to use look-ahead rules; but these are computationally very expensive. Instead rpart uses one of several measures of impurity, or diversity, of a node.

Also, similar descriptions here and here.

CODE: Here's a reprex. Any insight would be great - thanks!

suppressPackageStartupMessages(library(rsample))                                                                                                           
#> Warning: package 'rsample' was built under R version 3.4.4
suppressPackageStartupMessages(library(rpart))                                                                                                             
suppressPackageStartupMessages(library(caret))                                                                                                             
suppressPackageStartupMessages(library(dplyr))                                                                                                             
#> Warning: package 'dplyr' was built under R version 3.4.3
suppressPackageStartupMessages(library(purrr))                                                                                                             

#################################################                                                                                                          

# look at data                                                                                                                                             
data(attrition)                                                                                                                                            
attrition_subset <- attrition %>% select(Attrition, OverTime, JobRole)                                                                                     
attrition_subset %>% glimpse()                                                                                                                             
#> Observations: 1,470
#> Variables: 3
#> $ Attrition <fctr> Yes, No, Yes, No, No, No, No, No, No, No, No, No, N...
#> $ OverTime  <fctr> Yes, No, Yes, Yes, No, No, Yes, No, No, No, No, Yes...
#> $ JobRole   <fctr> Sales_Executive, Research_Scientist, Laboratory_Tec...
map_dfr(.x = attrition_subset, .f = ~ sum(is.na(.x)))                                                                                                      
#> # A tibble: 1 x 3
#>   Attrition OverTime JobRole
#>       <int>    <int>   <int>
#> 1         0        0       0

#################################################                                                                                                          

# with rpart                                                                                                                                               
attrition_rpart_w_JobRole <- rpart(Attrition ~ OverTime + JobRole, data = attrition_subset, method = "class", cp = .01)                                    
attrition_rpart_w_JobRole                                                                                                                                  
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>  1) root 1470 237 No (0.83877551 0.16122449)  
#>    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
#>    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
#>      6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
#>      7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
#>       14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
#>       15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

attrition_rpart_wo_JobRole <- rpart(Attrition ~ OverTime, data = attrition_subset, method = "class", cp = .01)                                             
attrition_rpart_wo_JobRole                                                                                                                                 
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245) *

#################################################                                                                                                          

# with caret                                                                                                                                               
attrition_caret_w_JobRole_non_dummies <- train(x = attrition_subset[ , -1], y = attrition_subset[ , 1], method = "rpart", tuneGrid = expand.grid(cp = .01))
attrition_caret_w_JobRole_non_dummies$finalModel                                                                                                           
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>  1) root 1470 237 No (0.83877551 0.16122449)  
#>    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
#>    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
#>      6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
#>      7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
#>       14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
#>       15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

attrition_caret_w_JobRole <- train(Attrition ~ OverTime + JobRole, data = attrition_subset, method = "rpart", tuneGrid = expand.grid(cp = .01))            
attrition_caret_w_JobRole$finalModel                                                                                                                       
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245)  
#>   2) OverTimeYes< 0.5 1054 110 No (0.8956357 0.1043643) *
#>   3) OverTimeYes>=0.5 416 127 No (0.6947115 0.3052885)  
#>     6) JobRoleSales_Representative< 0.5 392 111 No (0.7168367 0.2831633) *
#>     7) JobRoleSales_Representative>=0.5 24   8 Yes (0.3333333 0.6666667) *

attrition_caret_wo_JobRole <- train(Attrition ~ OverTime, data = attrition_subset, method = "rpart", tuneGrid = expand.grid(cp = .01))                     
attrition_caret_wo_JobRole$finalModel                                                                                                                      
#> n= 1470 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 1470 237 No (0.8387755 0.1612245) *
Asked Nov 08 '22 by sdevine188
1 Answer

This makes perfect sense. You have quite a bit of extra code above, so I will repeat the essential part.

library(rsample)
library(rpart)
data(attrition)

rpart(Attrition ~ OverTime + JobRole, data=attrition)
n= 1470 
node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 1470 237 No (0.83877551 0.16122449)  
   2) OverTime=No 1054 110 No (0.89563567 0.10436433) *
   3) OverTime=Yes 416 127 No (0.69471154 0.30528846)  
     6) JobRole=Healthcare_Representative,Manager,Manufacturing_Director,Research_Director 126  11 No (0.91269841 0.08730159) *
     7) JobRole=Human_Resources,Laboratory_Technician,Research_Scientist,Sales_Executive,Sales_Representative 290 116 No (0.60000000 0.40000000)  
      14) JobRole=Human_Resources,Research_Scientist,Sales_Executive 204  69 No (0.66176471 0.33823529) *
      15) JobRole=Laboratory_Technician,Sales_Representative 86  39 Yes (0.45348837 0.54651163) *

rpart(Attrition ~ OverTime, data=attrition)
n= 1470 
node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 1470 237 No (0.8387755 0.1612245) *

Take a look at the first model (with two variables). Just below the root we have:

1) root 1470 237 No (0.83877551 0.16122449)        
    2) OverTime=No 1054 110 No (0.89563567 0.10436433) *      
    3) OverTime=Yes 416 127 No (0.69471154 0.30528846)

The model goes on to split node 3 (OverTime=Yes), but only using JobRole. Since JobRole is absent from the second model, rpart can't make those further splits. Notice, though, that at both nodes 2 and 3 the majority class is Attrition=No: at node 3, 69.5% of the instances are No and 30.5% are Yes. So we would predict No at both nodes 2 and 3. Since the prediction is the same on either side of the split, the split does not reduce the misclassification loss at all, so it falls below the complexity threshold (the default cp = 0.01) and is pruned away. The root node alone suffices to predict No for every instance.
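
You can see this in the numbers directly: the loss at the root is 237, and after the OverTime split it is 110 + 127 = 237, a reduction of exactly zero. A sketch of how to confirm this (a negative cp in rpart.control disables the complexity-based stopping rule, so zero-improvement splits are retained):

```r
library(rsample)
library(rpart)
data(attrition)

# Loss at the root: 237 "Yes" cases misclassified as "No".
# Loss after splitting on OverTime: both children still predict "No",
# so the loss is 110 + 127 = 237 -- no reduction at all.

# A negative cp keeps zero-improvement splits, so the OverTime split appears:
forced <- rpart(Attrition ~ OverTime, data = attrition,
                method = "class", control = rpart.control(cp = -1))
forced
printcp(forced)  # the OverTime split contributes no reduction in rel error
```

With any cp >= 0, that zero-improvement split is discarded, which is exactly what you observed.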

Answered Nov 15 '22 by G5W