 

ctree() - How to get the list of splitting conditions for each terminal node?

I have an output from ctree() (party package) that looks like the following. How do I get the list of splitting conditions for each terminal node, e.g. sns <= 0, dta <= 1; sns <= 0, dta > 1; and so on?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6

Thanks

Sriram Murali asked Jan 29 '14 21:01


2 Answers

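This function should do the job. It works on a ctree() fit from the party package and uses where() and nodes() to reconstruct the path from the root to each terminal node: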

CtreePathFunc <- function(ct, data) {

  ResulTable <- data.frame(Node = character(), Path = character())
  TerminalNodes <- unique(where(ct))

  for (Node in TerminalNodes) {
    # Taking all possible non-terminal nodes that are smaller than the selected terminal node
    NonTerminalNodes <- setdiff(1:(Node - 1), TerminalNodes[TerminalNodes < Node])

    # Getting the weights for that node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # Finding the path: non-terminal nodes that share observations with this terminal node
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]]$weights == 1)) Path <- append(Path, i)
    }

    # Finding the splitting criteria for that path
    Path2 <- SB <- NULL

    for (i in 1:length(Path)) {
      # The next node on the path (or the terminal node itself at the last step)
      if (i == length(Path)) {
        n <- nodes(ct, Node)[[1]]
      } else {
        n <- nodes(ct, Path[i + 1])[[1]]
      }

      # Primary split of the i-th node: the last element is the variable name,
      # the third element is the split point
      PSplit <- unlist(nodes(ct, Path[i])[[1]][[5]])
      VarName <- as.character(PSplit[length(PSplit)])
      SplitPoint <- PSplit[3]

      # If every observation in the next node is <= the split point, we went left
      if (all(data[which(as.logical(n$weights)), VarName] <= as.numeric(SplitPoint))) {
        SB <- "<="
      } else {
        SB <- ">"
      }
      Path2 <- paste(c(Path2, paste(VarName, SB, as.character(SplitPoint))),
                     collapse = ", ")
    }

    # Output
    ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
  }
  return(ResulTable)
}

Testing

library(party)
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq, controls = ctree_control(maxsurrogate = 3))
Result <- CtreePathFunc(ct, airq)
Result 

##   Node                               Path
## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2    3            Temp <= 82, Wind <= 6.9
## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
## 4    9             Temp > 82, Wind > 10.3
## 5    8            Temp > 82, Wind <= 10.3
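Instead of matching node weights, the path can also be read directly off the tree structure by walking it recursively from the root. The sketch below is a minimal illustration that assumes the internal layout of a party BinaryTree: the root node sits in the `tree` slot, each node is a list with `terminal`, `nodeID`, `psplit`, `left` and `right`, and the primary split exposes `variableName` and `splitpoint`. It only handles numeric or ordered splits, and `ctree_paths()` is a hypothetical helper name.

```r
library("party")

# Hypothetical helper (not part of party): walk a party::ctree() tree from the
# root and record the primary split condition taken at each step.
# Assumes ordered (numeric) primary splits; nominal splits are rejected.
ctree_paths <- function(ct) {
  out <- character(0)
  walk <- function(node, conds) {
    if (node$terminal) {
      # Store the accumulated conditions under the terminal node's ID
      out[as.character(node$nodeID)] <<- paste(conds, collapse = ", ")
      return(invisible(NULL))
    }
    ps <- node$psplit
    if (!inherits(ps, "orderedSplit"))
      stop("this sketch only handles ordered (numeric) splits")
    # Ordered factors store the split point as a level index
    lv <- attr(ps$splitpoint, "levels")
    sp <- if (is.null(lv)) ps$splitpoint else lv[ps$splitpoint]
    le <- paste(ps$variableName, "<=", sp)
    gt <- paste(ps$variableName, ">", sp)
    # Assumption: the left daughter holds "<=" unless toleft says otherwise
    left_le <- is.null(ps$toleft) || isTRUE(as.logical(ps$toleft))
    walk(node$left,  c(conds, if (left_le) le else gt))
    walk(node$right, c(conds, if (left_le) gt else le))
  }
  walk(ct@tree, character(0))
  out
}

airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq)
paths <- ctree_paths(ct)
data.frame(Node = names(paths), Path = unname(paths))
```

Because the conditions are collected top-down, each path lists the splits in the order they are applied, and the terminal nodes come out in left-to-right order rather than in the order returned by unique(where(ct)).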
David Arenburg answered Sep 30 '22 00:09


If you use the newer, recommended partykit implementation of ctree() rather than the old party package, then you can use the function .list.rules.party(). It is not officially exported yet, but it can be leveraged to extract the desired information.

library("partykit")
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq)
partykit:::.list.rules.party(ct)
##                                      3                                      5 
##             "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77" 
##                                      6                                      8 
##  "Temp <= 82 & Wind > 6.9 & Temp > 77"             "Temp > 82 & Wind <= 10.3" 
##                                      9 
##              "Temp > 82 & Wind > 10.3" 
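Since .list.rules.party() is internal and could change between releases, a similar result can be sketched with partykit's exported accessors (node_party(), kids_node(), split_node(), varid_split(), breaks_split(), index_split(), right_split(), id_node(), is.terminal()). The helper name `path_rules()` is hypothetical, and the sketch assumes every split is a binary split at a single numeric break point; factor splits would need extra handling via index_split().

```r
library("partykit")

# Hypothetical helper built only on exported partykit accessors: returns a
# named character vector with one rule per terminal node.
# Assumes binary splits on numeric variables only.
path_rules <- function(ct) {
  dat <- ct$data                      # model frame the tree was grown on
  rules <- character(0)
  walk <- function(node, conds) {
    if (is.terminal(node)) {
      rules[as.character(id_node(node))] <<- paste(conds, collapse = " & ")
      return(invisible(NULL))
    }
    sp <- split_node(node)
    var <- names(dat)[varid_split(sp)]
    brk <- breaks_split(sp)
    if (is.null(brk) || length(brk) != 1L)
      stop("this sketch only handles binary splits at a single numeric break")
    # index_split() maps interval j to kid index[j]; NULL means the default 1:2
    idx <- index_split(sp)
    if (is.null(idx)) idx <- 1:2
    # right = TRUE means intervals (-Inf, brk] and (brk, Inf)
    right <- right_split(sp)
    lo <- paste(var, if (right) "<=" else "<", brk)
    hi <- paste(var, if (right) ">" else ">=", brk)
    kids <- kids_node(node)
    walk(kids[[idx[1]]], c(conds, lo))
    walk(kids[[idx[2]]], c(conds, hi))
  }
  walk(node_party(ct), character(0))
  rules
}

airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq)
rules <- path_rules(ct)
rules
```

Turning the result into a table like the one in the other answer is then just data.frame(Node = names(rules), Path = unname(rules)).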
Achim Zeileis answered Sep 30 '22 00:09