How can I get the ID (or name) of the terminal node of an rpart model for every row? predict.rpart can return only the predicted class (number or factor), the class probabilities, or some combination of these (using type = "matrix") for a classification tree.
I would like to do something like:
library(rpart) # provides rpart() and the kyphosis data
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit) # there are 5 terminal nodes
predict(fit, type = "node_id") # should return IDs of terminal nodes (e.g. 1-5) (does not work)
The partykit package supports predict(..., type = "node"), both in and out of sample. You can simply convert the rpart object to use this:
library("partykit")
predict(as.party(fit), type = "node")
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 9 5 8
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 5 7 9
## 5
table(predict(as.party(fit), type = "node"))
## 3 5 7 8 9
## 29 12 14 7 19
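As a minimal sketch of the out-of-sample case mentioned above, the tree can be fitted on part of the data and the node IDs requested for the held-out rows. The split into `train` and `test`, the seed, and the variable names are assumptions made for illustration only; the partykit package must be installed.

```r
library(rpart)
library(partykit)

# Hypothetical train/test split of the kyphosis data (81 rows)
set.seed(1)
idx   <- sample(nrow(kyphosis), 60)
train <- kyphosis[idx, ]
test  <- kyphosis[-idx, ]

# Fit on the training rows only
fit_train <- rpart(Kyphosis ~ Age + Number + Start, data = train)

# Terminal node ID for every held-out row
test_nodes <- predict(as.party(fit_train), newdata = test, type = "node")
table(test_nodes)
```

The IDs returned here follow partykit's own node numbering, which need not match the node numbers that print(fit_train) shows.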
That model has 4 splits, yielding 5 "terminal nodes", or <leaf>s in rpart's terminology. I do not see why there should be 5 predictions for anything: predictions are made for individual cases, and each leaf is the result of a variable number of the splits used to make those predictions. What you may want is the number of rows of the original dataset that ended up in each leaf, in which case these are ways of getting those numbers:
# Row-wise leaf assignment: for each case, the row of fit$frame holding its leaf
fit$where
# counts of cases in leaves of prediction rules
table(fit$where)
3 5 7 8 9
29 12 14 7 19
In order to assemble the labels(fit) that apply to a particular leaf, you would need to traverse the rule-tree and accumulate all the labels for all the splits that were applied to produce that leaf. You probably want to look at:
?print.rpart
?rpart.object
?text.rpart
?labels.rpart
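One possible way to do that traversal without writing it by hand is rpart's path.rpart(), which returns the split labels on the path from the root to each requested node. The sketch below is only an illustration; `leaf_ids`, `rules` and `rule_text` are names chosen here, not part of rpart.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Node numbers of the leaves: row names of fit$frame where var is "<leaf>"
leaf_ids <- as.integer(row.names(fit$frame))[fit$frame$var == "<leaf>"]

# List with one character vector per leaf; each starts with "root"
rules <- path.rpart(fit, nodes = leaf_ids, print.it = FALSE)

# Collapse each path (minus "root") into a single rule string
rule_text <- vapply(rules, function(p) paste(p[-1], collapse = " & "),
                    character(1))
rule_text
```

The names of `rule_text` are the rpart node numbers, so the rules can be matched to rows via row.names(fit$frame)[fit$where].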
The above method using $where returns only the row number in the tree frame (fit$frame), not the node number itself. So with kyphosis$ID = fit$where, observations are labelled with frame row indices rather than the actual leaf node IDs.
To get the actual leaf node ID, use the following:
MyID <- row.names(fit$frame)
kyphosis$ID <- MyID[fit$where]
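To see the difference between the two numberings, the frame row indices from $where can be cross-tabulated against the node numbers taken from the row names. This is a small sketch; `node_id` is a name introduced here for illustration.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Node number (row name of fit$frame) for each observation's leaf
node_id <- as.integer(row.names(fit$frame))[fit$where]

# Each frame row index maps to exactly one node number
table(frame_row = fit$where, node_number = node_id)

# Sanity check: every observation sits in a leaf row of the frame
all(fit$frame$var[fit$where] == "<leaf>")
```

The node numbers obtained this way are the ones shown by print(fit), which makes them easier to match against the printed tree.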