 

Getting the observations in an rpart node (i.e., CART)

I would like to inspect all the observations that reached some node in an rpart decision tree. For example, in the following code:

library(rpart)  # provides rpart() and the kyphosis data
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

I would like to see all the observations in node 5 (i.e., the 33 observations for which Start >= 8.5 & Start < 14.5). Obviously I could get to them manually, but I would like a function such as (say) get_node_date, so that I could just run get_node_date(5) and get the relevant observations.
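For reference, the manual route for node 5 is to combine the split rules on the path from the root into one logical condition. A minimal sketch; the two conditions are copied by hand from the printed tree, which is exactly what a generic function should avoid:

```r
library(rpart)  # provides the kyphosis data set

# Node 5 is reached via Start>=8.5 (node 2) and then Start< 14.5 (node 5)
node5 <- subset(kyphosis, Start >= 8.5 & Start < 14.5)
nrow(node5)  # 33, matching n for node 5 in the printed tree
```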

Any suggestions on how to go about this?

asked Apr 20 '16 at 15:04 by Tal Galili


2 Answers

There seems to be no function that extracts the observations of a specific node directly. I would solve it as follows: first determine which rules lead to the node you are interested in; path.rpart does this. Then apply the rules one after the other to extract the observations.

This approach as a function:

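# Return the rows of `data` that reach `node` of an rpart tree.
# Only handles numeric splits of the form "var>=value" / "var< value".
get_node_date <- function(tree = fit, node = 5, data = kyphosis){
  rule <- path.rpart(tree, node)  # also prints the path, as in the output below
  # Split each rule (dropping "root") into variable, operator and value
  rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
  # Evaluate every rule on the data and keep the rows that satisfy all of them
  ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], data[, x[1]], as.numeric(x[3]))))), 1, all)
  data[ind, ]
}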

For node 5 you get:

get_node_date()

 node number: 5 
   root
   Start>=8.5
   Start< 14.5
   Kyphosis Age Number Start
2    absent 158      3    14
10  present  59      6    12
11  present  82      5    14
14   absent   1      4    12
18   absent 175      5    13
20   absent  27      4     9
23  present  96      3    12
26   absent   9      5    13
28   absent 100      3    14
32   absent 125      2    11
33   absent 130      5    13
35   absent 140      5    11
37   absent   1      3     9
39   absent  20      6     9
40  present  91      5    12
42   absent  35      3    13
46  present 139      3    10
48   absent 131      5    13
50   absent 177      2    14
51   absent  68      5    10
57   absent   2      3    13
59   absent  51      7     9
60   absent 102      3    13
66   absent  17      4    10
68   absent 159      4    13
69   absent  18      4    11
71   absent 158      5    14
72   absent 127      4    12
74   absent 206      4    10
77  present 157      3    13
78   absent  26      7    13
79   absent 120      2    13
81   absent  36      4    13
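The rules that path.rpart returns are the splits leading to each ancestor of the node. In rpart's numbering the children of node k are 2k and 2k + 1, so the ancestor chain can be read off the node number alone. A small sketch; node_ancestors is a hypothetical helper, not part of rpart:

```r
# Hypothetical helper: ancestors of an rpart node, listed from the root down.
# Relies on rpart's numbering, where node k has children 2k and 2k + 1.
node_ancestors <- function(node) {
  path <- node
  while (node > 1) {
    node <- node %/% 2      # integer division gives the parent
    path <- c(node, path)
  }
  path
}

node_ancestors(5)   # 1 2 5
node_ancestors(23)  # 1 2 5 11 23
```

For node 5 this gives nodes 2 and 5 below the root, i.e. the two rules Start>=8.5 and Start< 14.5 printed above.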
answered Sep 23 '22 at 16:09 by DatamineR


Two years after the original post, but this may be of use to others. Node assignments for training observations in rpart can be obtained from $where:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where
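Note that fit$where holds, for each training observation, the row number in fit$frame of the leaf it falls into, not the node label shown when the tree is printed; the labels are the row names of fit$frame. A quick sketch of the translation, which also checks the leaf sizes against the printed tree:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

# Translate frame row numbers into the node labels used when printing the tree
leaf_node <- as.integer(rownames(fit$frame))[fit$where]
table(leaf_node)  # leaves 3, 4, 10, 22, 23 with 19, 29, 12, 14, 7 observations
```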

As a function:

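get_node <- function(rpart.object = fit, data = kyphosis, node.number = 10) {
  # $where gives the row of rpart.object$frame for each observation's leaf;
  # the node numbers shown when printing the tree are that frame's row names
  leaf <- as.integer(rownames(rpart.object$frame))[rpart.object$where]
  data[which(leaf == node.number), ]  # node.number must be a terminal node
}
get_node()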

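This works only for training observations, not for new observations. It also only works for terminal (leaf) nodes: $where only ever refers to leaves, so internal nodes such as node 5 cannot be selected this way.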
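One way to cover internal nodes as well, for the training data, is to keep every observation whose leaf lies below the requested node, again using the 2k / 2k + 1 numbering. A sketch under those assumptions; get_node_any is a hypothetical name, and `data` must be the exact training data in its original row order (no rows dropped for missing values):

```r
library(rpart)

# Hypothetical helper: training observations passing through any node,
# internal or terminal. Relies on rpart's numbering, where the children of
# node k are 2k and 2k + 1.
get_node_any <- function(rpart.object, data, node.number) {
  # Node label of the leaf each observation ends up in
  anc <- as.integer(rownames(rpart.object$frame))[rpart.object$where]
  # Climb each leaf towards the root until it is at or above node.number
  while (any(anc > node.number)) {
    anc <- ifelse(anc > node.number, anc %/% 2L, anc)
  }
  data[anc == node.number, ]
}

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
nrow(get_node_any(fit, kyphosis, 5))   # 33 (leaves 10, 22 and 23)
nrow(get_node_any(fit, kyphosis, 10))  # 12
```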

answered Sep 22 '22 at 16:09 by Marjolein Fokkema