Similar questions have been asked, for example here and here, but none of them applies to my issue. I'm trying to determine and count which observations fall into each node of a decision tree. However, the tree structure comes from a data frame of trees that I build myself by extracting tree information from the BART package, and it resembles the one shown below (i.e., df). I need to work with the data frame structure provided. Aside: I believe the order in which the trees are laid out in my data frame is called 'depth first'.
For example, my data frame of trees looks like this:
library(dplyr)
df <- tibble(
  variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
  splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
  treeNo = c(1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3)
)
Visually, these trees would look like:
The trees are drawn left-first when traversing down df. Additionally, all splits are binary, so each internal node has exactly two children (rows with NA are leaves).
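To make the depth-first, left-first layout concrete, here is a minimal sketch that recovers the depth of every row of a single tree from the NA pattern alone. The helper name node_depth is hypothetical (it is not part of BART or dplyr), and the sketch assumes that each tree is a complete pre-order listing in which NA rows are leaves.

```r
# Hypothetical helper: depth of each node in one tree stored in pre-order
# (parent, then left subtree, then right subtree). `is_leaf` is TRUE for NA rows.
node_depth <- function(is_leaf) {
  depth <- integer(length(is_leaf))
  pending <- integer(0)  # depths of right children that still have to be visited
  d <- 0L                # depth of the row currently being visited
  for (i in seq_along(is_leaf)) {
    depth[i] <- d
    if (!is_leaf[i]) {
      pending <- c(pending, d + 1L)  # remember the right child
      d <- d + 1L                    # the next row is the left child
    } else if (length(pending) > 0L) {
      d <- pending[length(pending)]  # the next row is the nearest waiting right child
      pending <- pending[-length(pending)]
    }
  }
  depth
}

# Tree 1 from the question: x2, x1, NA, NA, NA
node_depth(is.na(c("x2", "x1", NA, NA, NA)))
# 0 1 2 2 1
```

Read this way, row 5 of tree 1 is the right child of the root, which matches the left-first drawing described above.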
So, if we create some data that looks like this:
set.seed(100)
dat <- data.frame(
  x1 = runif(10),
  x2 = runif(10),
  x3 = runif(10),
  x4 = runif(10),
  x5 = runif(10)
)
I'm trying to find which observations of dat fall into which node.
Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:
lists <- df %>% group_by(treeNo) %>% group_split()
tree<- lists[[3]]
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1,]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1,]$splitValue, ]
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2,]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2,]$splitValue, ]
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5,]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5,]$splitValue, ]
I have been trying to turn this into a loop, but it is proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?
It seems that "rolling splits" give you what you are looking for. The logic is as follows.
1. Start with a stack that holds a single data frame, dat.
2. For each pair of variableName and splitValue: if they are not NA, replace the data frame on top of the stack with two sub data frames, variableName <= splitValue and variableName > splitValue (the former on top of the latter); if they are NA, simply pop the top data frame.

Here is the code. Note that this kind of state-dependent computation is hard to vectorize, so it is not what R is good at. If you have a lot of trees and performance becomes a serious concern, I'd suggest rewriting the code below using Rcpp.
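eval_node <- function(df, x, v) {
  out <- vector("list", length(x))      # one entry per row of the tree
  stk <- vector("list", sum(is.na(x)))  # stack of data frames still to be visited
  pos <- 1L                             # index of the top of the stack
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      # Internal node: replace the top of the stack with its two children.
      subs <- pos + c(0L, 1L)
      # Fixed factor levels keep both children even if one side is empty;
      # a bare logical would make split() return fewer than two pieces.
      go_left <- factor(stk[[pos]][[x[[i]]]] <= v[[i]], levels = c(FALSE, TRUE))
      stk[subs] <- split(stk[[pos]], go_left)
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])
      pos <- pos + 1L
    } else {
      # Leaf: record the data frame on top of the stack, then pop it.
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}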
Then you can apply the function like this.
library(dplyr)
df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))
Output
# A tibble: 15 x 4
# Groups: treeNo [3]
variableName splitValue treeNo node
<chr> <dbl> <dbl> <list>
1 x2 0.542 1 <named list [2]>
2 x1 0.126 1 <named list [2]>
3 NA NA 1 <named list [1]>
4 NA NA 1 <named list [1]>
5 NA NA 1 <named list [1]>
6 x2 0.655 2 <named list [2]>
7 NA NA 2 <named list [1]>
8 NA NA 2 <named list [1]>
9 x5 0.418 3 <named list [2]>
10 x4 0.234 3 <named list [2]>
11 NA NA 3 <named list [1]>
12 NA NA 3 <named list [1]>
13 x3 0.747 3 <named list [2]>
14 NA NA 3 <named list [1]>
15 NA NA 3 <named list [1]>
where node looks like this:
[[1]]
[[1]]$`x2<=0.542`
x1 x2 x3 x4 x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[1]]$`x2>0.542`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
[[2]]$`x2<=0.542,x1>0.126`
x1 x2 x3 x4 x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
[[4]]
[[4]]$`x2<=0.542,x1>0.126`
x1 x2 x3 x4 x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[5]]
[[5]]$`x2>0.542`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[6]]
[[6]]$`x2<=0.6547`
x1 x2 x3 x4 x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[6]]$`x2>0.6547`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[7]]
[[7]]$`x2<=0.6547`
x1 x2 x3 x4 x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
[[8]]
[[8]]$`x2>0.6547`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[9]]
[[9]]$`x5<=0.418`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[9]]$`x5>0.418`
x1 x2 x3 x4 x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
x1 x2 x3 x4 x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[10]]$`x5<=0.418,x4>0.234`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
x1 x2 x3 x4 x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859
[[12]]
[[12]]$`x5<=0.418,x4>0.234`
x1 x2 x3 x4 x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
[[13]]
[[13]]$`x5>0.418,x3<=0.747`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
[[13]]$`x5>0.418,x3>0.747`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
[[14]]
[[14]]$`x5>0.418,x3<=0.747`
x1 x2 x3 x4 x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
[[15]]
[[15]]$`x5>0.418,x3>0.747`
x1 x2 x3 x4 x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
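Since the question also asks for counts, here is a minimal sketch of a variation that tracks row indices instead of whole data frames. The helper name node_rows and the toy data are made up for illustration and are not part of BART or dplyr. Note one difference from the output above: for a split row it records the observations that reach that node, not its two children.

```r
# Hypothetical helper: for each row of one tree (pre-order, left child first),
# return the row indices of `data` that reach that node.
node_rows <- function(data, var, split) {
  out <- vector("list", length(var))
  stack <- list(seq_len(nrow(data)))  # every observation starts at the root
  for (i in seq_along(var)) {
    top <- stack[[length(stack)]]
    stack[[length(stack)]] <- NULL    # pop the node being visited
    out[[i]] <- top
    if (!is.na(var[i])) {
      left <- data[[var[i]]][top] <= split[i]
      # Push the right child first so that the left child is visited next.
      stack <- c(stack, list(top[!left]), list(top[left]))
    }
  }
  out
}

# Toy data (made up) run through the layout of tree 3 from the question.
toy <- data.frame(
  x3 = c(0.1, 0.1, 0.7, 0.8),
  x4 = c(0.1, 0.3, 0.5, 0.5),
  x5 = c(0.1, 0.2, 0.5, 0.9)
)
var <- c("x5", "x4", NA, NA, "x3", NA, NA)
split <- c(0.418, 0.234, NA, NA, 0.747, NA, NA)

nodes <- node_rows(toy, var, split)
lengths(nodes)  # number of observations per node
# 4 2 1 1 2 1 1
```

Under the same assumptions, it should be possible to apply this per tree with df %>% group_by(treeNo) %>% mutate(n = lengths(node_rows(dat, variableName, splitValue))), which adds one count per row of df.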