So I have used the rpart package to create a tree model and I found an interesting rule and wondered if there was an easy way to see which observations in that data frame pass that rule.
It seems very tedious to use path.rpart to find the path it took down the tree, and manually enter those filters into the data frame to look for them. Is there a method where I can pass a tree and/or a node, and a data frame and return all the elements in that frame that ended at that node?
I modified the code in path.rpart to return the subset of the data that falls within a particular node, rather than information about that node. It works either by clicking on the plot or by passing node numbers, just as the path.rpart function does. Here is the code:
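subset.rpart <- function (tree, df, nodes) {
    if (!inherits(tree, "rpart"))
        stop("Not a legitimate \"rpart\" object")
    # df must be the data the tree was fitted on: tree$where has one entry per row
    stopifnot(nrow(df) == length(tree$where))
    frame <- tree$frame
    n <- row.names(frame)
    node <- as.numeric(n)
    if (missing(nodes)) {
        # interactive mode: pick a node by clicking on the plotted tree
        xy <- rpart:::rpartco(tree)
        i <- identify(xy, n = 1L, plot = FALSE)
        # guard against an empty result when no node was picked
        if (length(i) == 1L && i > 0L) {
            return(df[tree$where == i, ])
        } else {
            return(df[0, ])
        }
    }
    else {
        # node.match() maps node numbers to row positions in tree$frame
        if (length(nodes <- rpart:::node.match(nodes, node)) == 0L)
            return(df[0, ])
        # tree$where records leaves only, so an internal node yields zero rows
        return(df[tree$where %in% as.numeric(nodes), ])
    }
}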
I will use it on some sample data from the package:
library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)
And then to find the observations at a particular node, run
subset.rpart(fit, kyphosis)
and click on a node in the plot. All the observations at that node will then be returned. You must use the same data.frame that was used for modeling for this to work properly. Rather than clicking on a point, you can also pass in a node number that you discover with path.rpart:
# path.rpart(fit)
# node number: 10 ---> looks interesting
# root
# Start>=8.5
# Start< 14.5
# Age< 55
subset.rpart(fit, kyphosis, 10)
# Kyphosis Age Number Start
# 14 absent 1 4 12
# 20 absent 27 4 9
# 26 absent 9 5 13
# 37 absent 1 3 9
# 39 absent 20 6 9
# 42 absent 35 3 13
# 57 absent 2 3 13
# 59 absent 51 7 9
# 66 absent 17 4 10
# 69 absent 18 4 11
# 78 absent 26 7 13
# 81 absent 36 4 13
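The lookup that the function above relies on can also be checked without any plotting. The following is a minimal sketch, assuming the same kyphosis fit as above; `leaf_node` and `node10` are names introduced here for illustration only. It uses two documented components of an rpart object: `where`, which gives the row of `frame` holding each observation's leaf, and the row names of `frame`, which are the node numbers.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# fit$where: for each training row, the row index in fit$frame of its leaf.
# row.names(fit$frame): the node numbers, so indexing translates rows to nodes.
leaf_node <- as.numeric(row.names(fit$frame))[fit$where]

# Training rows that ended up in leaf number 10
node10 <- kyphosis[leaf_node == 10, ]

# Number of observations per leaf
table(leaf_node)
```

Because `where` only records terminal nodes, this direct comparison finds rows for leaves only; an internal node number matches nothing.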
#' subset of rpart node: return logical index
#' @param tree rpart model
#' @param node which node/leaf?
#' @export
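subset_rpart <- function (tree, node) {
    # node numbers are the row names of tree$frame; work on a log2 scale
    nodes = as.numeric(rownames(tree$frame))
    nodes = log(nodes, 2)
    lower = log(node, 2)
    upper = log(node + 1, 2)
    # compare fractional parts so that every depth below `node` is covered
    a = floor(lower)
    lower_ = lower - a
    upper_ = upper - a
    nodes_ = nodes %% 1
    # frame rows that are `node` itself or one of its descendants
    w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
    # tree$where holds the frame row of each observation's leaf
    tree$where %in% w
}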
#' subset df by subset_rpart
#' @param tree rpart model
#' @param node node number
#' @param df df
#' @export
subset.rpart = function(tree, node, df){
    df[subset_rpart(tree, node), ]
}
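The log-based test above works because rpart numbers the children of node k as 2k and 2k + 1, so the descendants of a node occupy a predictable range at each depth. As a hedged alternative, the same idea can be sketched with integer arithmetic by walking each leaf up towards the root (the parent of node k is k %/% 2). `in_subtree` is a hypothetical helper written for this illustration, not part of rpart, and it assumes the data frame being indexed is the one the tree was fitted on.

```r
library(rpart)

# Hypothetical helper: TRUE for training rows whose leaf is `node` itself
# or lies anywhere below `node` in the tree.
in_subtree <- function(tree, node) {
    stopifnot(inherits(tree, "rpart"))
    # node number of the leaf each observation ended in
    anc <- as.numeric(row.names(tree$frame))[tree$where]
    hit <- anc == node
    # climb one level at a time; ancestors shrink, so stop once all are <= node
    while (any(anc > node)) {
        anc <- ifelse(anc > node, anc %/% 2, anc)
        hit <- hit | anc == node
    }
    hit
}

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# All training rows that pass through node 2, whichever leaf they end in
rows_node2 <- kyphosis[in_subtree(fit, 2), ]
```

Unlike a plain comparison against the leaf number, this also works for internal nodes, and it avoids relying on fractional parts of logarithms.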