 

Extract split values from rpart object in R

I can't find the split values (or other per-node data) in an rpart object. I can see them with summary(sample_model), but not anywhere in the model's list components or in its frame data frame.

Some sample data:

foo.df <- structure(
  list(
    type = c("fudai", "fudai", "fudai", "fudai", "fudai",
             "fudai", "fudai", "tozama", "fudai", "fudai", "tozama", "tozama",
             "fudai", "tozama", "fudai", "fudai", "tozama", "fudai", "fudai",
             "tozama", "fudai", "fudai", "fudai", "tozama", "fudai", "fudai",
             "tozama", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama",
             "fudai", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama",
             "tozama", "fudai", "tozama", "tozama", "tozama", "tozama", "fudai",
             "fudai", "tozama", "tozama"),
    distance = c(12.5366985071383,
                 272.697138147139, 40.4780423740381, 109.806349869662, 147.781805212839,
                 89.4280438527415, 49.1425850803745, 555.414271440522, 119.365138867582,
                 182.902536555383, 310.019126513348, 277.122207392514, 214.510428881317,
                 235.111617874157, 104.494518693549, 50.7561853895564, 343.308898045237,
                 151.796857505073, 36.0391449169937, 30.8214406651022, 343.294467363406,
                 135.841501028422, 154.798119311647, 317.739208576563, 3.33794280697559,
                 98.9182898110913, 422.915369767251, 194.957988642709, 87.6548263591412,
                 187.571370158631, 236.292608259126, 17.915709270268, 193.548578374405,
                 262.190146422316, 21.6219797945323, 121.199009527283, 261.670997612517,
                 202.2051991431, 125.418459536787, 275.964068539003, 190.112226847932,
                 20.1753302760961, 488.80323504215, 579.25515722891, 233.500797034697,
                 207.588349435329, 183.770003408524, 168.739293254246, 313.140075747773,
                 131.69228390613),
    age = c(1756, 1711, 1712, 1746, 1868, 1866,
            1682, 1617, 1771, 1764, 1672, 1636, 1864, 1704, 1762, 1868, 1694,
            1749, 1703, 1616, 1691, 1702, 1723, 1683, 1742, 1691, 1623, 1721,
            1704, 1745, 1749, 1723, 1639, 1661, 1843, 1845, 1669, 1698, 1698,
            1664, 1868, 1633, 1783, 1642, 1615, 1648, 1734, 1758, 1725, 1635)
  ),
  class = c("tbl_df", "tbl", "data.frame"),
  row.names = c(NA, -50L)
)

And a basic model:

library("rpart")
sample_model <- rpart(formula = type ~ .,
                      data = foo.df,
                      method = "class",
                      control = rpart.control(xval = 50, minbucket = 5, cp = 0.05),
                      parms = list(split = "gini"))

The rpart documentation says that sample_model$frame should contain a column (or columns) called "splits", but it isn't there. To quote: "splits, a two column matrix of left and right split labels for each node" https://www.rdocumentation.org/packages/rpart/versions/4.1-15/topics/rpart.object
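The "two column matrix of left and right split labels" described there does not appear to be stored in the fitted object at all. As far as I can tell from the rpart sources, it is built on demand by the labels() method for rpart objects when called with collapse = FALSE. A minimal sketch, using the kyphosis data that ships with rpart rather than foo.df:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Not a column of fit$frame: the left/right label matrix is computed
# on demand by the labels() method for rpart objects.
lr <- labels(fit, collapse = FALSE)
dim(lr)  # one row per node in fit$frame, two columns (left, right)
cbind(node = rownames(fit$frame), lr)
```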

Where are those columns, in sample_model$frame or elsewhere in sample_model? I can, however, see the data I want in the output of

summary(sample_model)

What's going on?
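For reference, the numeric split values themselves live in the splits matrix of the fitted object, not in frame. A minimal sketch on the kyphosis data (again not foo.df) showing both components; the row-count relationship checked below is the same index arithmetic summary.rpart uses:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Per-node summary (var, n, dev, yval, complexity, ncompete, nsurrogate, ...);
# there is no "splits" column here.
names(fit$frame)

# One row per main, competitor and surrogate split, node by node.
# rownames = split variable; for numeric variables "index" is the cut point,
# for factors it is a row number into fit$csplit.
fit$splits
fit$splits[, "index"]
```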

Mark R asked Oct 25 '25 14:10


1 Answer

The docs are indeed outdated. Here is an extractor derived by inspecting the summary.rpart function:


rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    # rows each node owns in fit$splits: main split + competitors +
    # surrogates (leaves own none)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    # first row in fit$splits belonging to each node
    ix <- cumsum(c(1L, nn))
    # rows holding each node's main split and its competitors
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    # condition for the left branch, formatted as in summary.rpart
    left <- character(n)
    side <- splits[, 2L]  # ncat: -1/+1 for numeric, number of levels for factors
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   # factor split: column 4 is a row of fit$csplit
                   # (1 = left, 2 = level not present, 3 = right)
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}
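The else branch of the loop decodes factor splits through fit$csplit. To see what those codes look like, here is a small sketch on hypothetical toy data (the data frame toy and its columns x and y are made up for illustration) in which the root split is on a factor:

```r
library(rpart)

# Hypothetical toy data: levels a/c vs b/d of the factor x separate
# the two classes perfectly, so the root split is categorical.
toy <- data.frame(
  x = factor(rep(c("a", "b", "c", "d"), each = 25)),
  y = factor(rep(c("yes", "no", "yes", "no"), each = 25))
)
fit_cat <- rpart(y ~ x, data = toy, method = "class")

s <- fit_cat$splits[1, ]  # the root's main split (a named numeric vector)
ncat <- s[["ncat"]]       # number of levels of x, here 4
# "index" points at a row of csplit; codes are 1 = level goes left,
# 2 = level not present at this node, 3 = level goes right
codes <- fit_cat$csplit[s[["index"]], seq_len(ncat)]
pattern <- paste(c("L", "-", "R")[codes], collapse = "")
setNames(c("L", "-", "R")[codes], levels(toy$x))
```

Which side a/c lands on is up to rpart, so the pattern may come out either way round.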

Filter on type == "main" to get only the main splits. Example with the kyphosis data that ships with rpart:

> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart_splits(fit)
      var      type node ix    left count ncat    improve index       adj
1   Start      main    1  1  >= 8.5    81    1 6.76232996   8.5 0.0000000
2  Number   primary    1  1   < 5.5    81   -1 2.86679493   5.5 0.0000000
3     Age   primary    1  1  < 39.5    81   -1 2.25021152  39.5 0.0000000
4  Number surrogate    1  1   < 6.5     0   -1 0.80246914   6.5 0.1578947
5   Start      main    2  2 >= 14.5    62    1 1.02052786  14.5 0.0000000
6     Age   primary    2  2    < 55    62   -1 0.68486352  55.0 0.0000000
7  Number   primary    2  2   < 4.5    62   -1 0.29753321   4.5 0.0000000
8  Number surrogate    2  2   < 3.5     0   -1 0.64516129   3.5 0.2413793
9     Age surrogate    2  2    < 16     0   -1 0.59677419  16.0 0.1379310
10    Age      main    5  4    < 55    33   -1 1.24675325  55.0 0.0000000
11  Start   primary    5  4 >= 12.5    33    1 0.28877005  12.5 0.0000000
12 Number   primary    5  4  >= 3.5    33    1 0.17532468   3.5 0.0000000
13  Start surrogate    5  4   < 9.5     0   -1 0.75757576   9.5 0.3333333
14 Number surrogate    5  4  >= 5.5     0    1 0.69696970   5.5 0.1666667
15    Age      main   11  6  >= 111    21    1 1.71428571 111.0 0.0000000
16  Start   primary   11  6 >= 12.5    21    1 0.79365079  12.5 0.0000000
17 Number   primary   11  6  >= 3.5    21    1 0.07142857   3.5 0.0000000
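If only the main splits are needed, the same row arithmetic can be done without the loop. A minimal vectorised sketch (main_splits is a hypothetical name, not part of rpart), again on kyphosis; it should pick out the same nodes and cut points as the "main" rows above:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
ff <- fit$frame
is_leaf <- ff$var == "<leaf>"

# First row in fit$splits owned by each node: every internal node owns
# one main split, then its competitors, then its surrogates.
first_row <- cumsum(c(1L, ff$ncompete + ff$nsurrogate + !is_leaf))
main_rows <- first_row[c(!is_leaf, FALSE)]

main_splits <- data.frame(
  node  = as.integer(row.names(ff))[!is_leaf],
  var   = as.character(ff$var[!is_leaf]),
  ncat  = unname(fit$splits[main_rows, "ncat"]),
  value = unname(fit$splits[main_rows, "index"]),
  stringsAsFactors = FALSE
)
main_splits
```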

VitoshKa answered Oct 27 '25 03:10