Find the data elements in a data frame that pass the rule for a node in a tree model?

Tags: r, rpart

I used the rpart package to build a tree model and found an interesting rule. Is there an easy way to see which observations in the data frame satisfy that rule?

It seems very tedious to use path.rpart to find the path down the tree and then manually apply those filters to the data frame. Is there a function that takes a tree (and/or a node) and a data frame and returns all the rows of that frame that ended up at that node?

mandroid asked Dec 26 '22 08:12


2 Answers

I modified the code in path.rpart to return the subset of the data that falls within a particular node, rather than information about that node. It works either by clicking on the plot or by passing node numbers, just as path.rpart does. Here is the code:

subset.rpart <- function (tree, df, nodes) {
    if (!inherits(tree, "rpart")) 
        stop("Not a legitimate \"rpart\" object")
    stopifnot(nrow(df)==length(tree$where))
    frame <- tree$frame
    n <- row.names(frame)
    node <- as.numeric(n)

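    if (missing(nodes)) {
        # Interactive mode: pick a node by clicking on the plotted tree
        xy <- rpart:::rpartco(tree)
        i <- identify(xy, n = 1L, plot = FALSE)
        # identify() returns integer(0) when nothing is selected
        if (length(i) > 0L) {
             return( df[tree$where==i, ] )
        } else {
            return(df[0,])
        }
    }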
    else {
        if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
            return(df[0,])
        return ( df[tree$where %in% as.numeric(nodes), ] )
    }
}

I will use it on the kyphosis sample data that ships with the package:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)

[Figure: rpart tree plot]

Then, to find the observations at a particular node, run:

subset.rpart(fit, kyphosis)

and click on a node in the plot. All the observations at that node are then returned. For this to work properly, you must pass the same data.frame that was used to fit the model. Note that tree$where records only the terminal node each observation ends up in, so rows are returned for leaf nodes only. Rather than clicking, you can also pass in a node number that you discover with path.rpart:

# path.rpart(fit)  
#  node number: 10  ---> looks interesting
#    root
#    Start>=8.5
#    Start< 14.5
#    Age< 55

subset.rpart(fit, kyphosis, 10)
#    Kyphosis Age Number Start
# 14   absent   1      4    12
# 20   absent  27      4     9
# 26   absent   9      5    13
# 37   absent   1      3     9
# 39   absent  20      6     9
# 42   absent  35      3    13
# 57   absent   2      3    13
# 59   absent  51      7     9
# 66   absent  17      4    10
# 69   absent  18      4    11
# 78   absent  26      7    13
# 81   absent  36      4    13
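
For a non-interactive check, the same lookup can also be sketched with plain indexing. This is only a minimal sketch, assuming the model was fit on kyphosis as above and that no rows were dropped during fitting; node_id and leaf10 are names introduced here for illustration. It relies on fit$where holding, for each observation, the row of fit$frame for its leaf, and on the row names of fit$frame being the node numbers.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# fit$where gives, per observation, the row of fit$frame for its leaf;
# the row names of fit$frame are the node numbers shown by print(fit).
node_id <- as.numeric(rownames(fit$frame))[fit$where]

# Number of observations in each leaf
table(node_id)

# Rows that ended up in leaf 10 (node_id and leaf10 are illustrative names)
leaf10 <- kyphosis[node_id == 10, ]
nrow(leaf10)
```

Like the function above, this only finds rows for leaf nodes, since internal node numbers never appear in node_id.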
MrFlick answered Jan 30 '23 23:01


#' Logical index of the observations that fall in a node or any of its descendants
#' @param tree fitted rpart model
#' @param node node number (internal node or leaf)
#' @export
subset_rpart <- function (tree, node) {
  nodes = as.numeric(rownames(tree$frame))
  nodes = log(nodes, 2)
  lower = log(node, 2)
  upper = log(node + 1, 2)
  a = floor(lower)
  lower_ = lower - a
  upper_  = upper - a
  nodes_ = nodes %% 1
  w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
  tree$where %in% w
}



#' Subset a data frame to the rows selected by subset_rpart
#' @param tree fitted rpart model
#' @param node node number
#' @param df data frame the model was fitted on
#' @export
subset.rpart = function(tree, node, df){
  df[subset_rpart(tree, node), ]
}
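
The logarithm test above relies on rpart's node numbering, in which the children of node k are numbered 2k and 2k + 1, so every descendant of k reduces to k under repeated integer halving. The following is a minimal sketch of the same selection using integer arithmetic instead of logarithms; is_in_subtree and rows_under_node are hypothetical names introduced here, not part of rpart, and the sketch assumes the model was fit on kyphosis with no rows dropped.

```r
library(rpart)

# TRUE for each id that equals `node` or descends from it
# (is_in_subtree is an illustrative helper, not an rpart function).
is_in_subtree <- function(ids, node) {
  vapply(ids, function(id) {
    while (id > node) id <- id %/% 2  # step up to the parent node
    id == node
  }, logical(1))
}

# Logical index over the training rows: TRUE if the observation
# ended in a leaf at or below `node`.
rows_under_node <- function(tree, node) {
  ids <- as.numeric(rownames(tree$frame))
  tree$where %in% which(is_in_subtree(ids, node))
}

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Internal node 5 (Start >= 8.5, then Start < 14.5): the count of selected
# rows should agree with the node size recorded in fit$frame.
n_under_5 <- sum(rows_under_node(fit, 5))
n_under_5 == fit$frame["5", "n"]
```

Used as kyphosis[rows_under_node(fit, 5), ], this returns the rows under an internal node, which the leaf-only lookup cannot do.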
chen.wq answered Jan 30 '23 23:01