Efficient algorithm for predicting from a decision tree (using R)

I'm working on a modification of Breiman's random forest procedure (and I don't know C/C++), so I've written my own RF variant from scratch in R. The difference between my procedure and the standard one is basically just in how the splitpoints and the values in terminal nodes are calculated -- once I have a tree in a forest, it can be thought of much like a tree from a typical RF algorithm.

My problem is that prediction from it is quite slow, and I'm having trouble thinking of ways of making it faster.

A test tree object is linked here, and some test data is linked here. You can download it directly, or you can load it below if you have repmis installed. They are called testtree and sampx.

library(repmis)
testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8")
sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")

EDIT: Somehow I still haven't gotten around to really learning how to use github well. I've uploaded the needed files to a repository here -- apologies that I can't figure out how to get a permalink at the moment...

It looks like this (using a plot function I've written): [plot of the tree omitted]

Here's a bit about how the object is structured:

1> summary(testtree)
         Length Class      Mode   
nodes       7   -none-     list   
minsplit    1   -none-     numeric
X          29   data.frame list   
y        6719   -none-     numeric
weights  6719   -none-     numeric
oob      2158   -none-     numeric
1> summary(testtree$nodes)
     Length Class  Mode
[1,] 4      -none- list
[2,] 8      -none- list
[3,] 8      -none- list
[4,] 7      -none- list
[5,] 7      -none- list
[6,] 7      -none- list
[7,] 7      -none- list
1> summary(testtree$nodes[[1]])
         Length Class  Mode   
y        6719   -none- numeric
output         1   -none- numeric
Terminal    1   -none- logical
children    2   -none- numeric
1> testtree$nodes[[1]][2:4]
$output
[1] 40.66925

$Terminal
[1] FALSE

$children
[1] 2 3

1> summary(testtree$nodes[[2]])
           Length Class  Mode     
y          2182   -none- numeric  
parent        1   -none- numeric  
splitvar      1   -none- character
splitpoint    1   -none- numeric  
handedness    1   -none- character
children      2   -none- numeric  
output        1   -none- numeric  
Terminal      1   -none- logical  
1> testtree$nodes[[2]][2:8]
$parent
[1] 1

$splitvar
[1] "bizrev_allHH"

$splitpoint
    25% 
788.875 

$handedness
[1] "Left"

$children
[1] 4 5

$output
[1] 287.0085

$Terminal
[1] FALSE

output is the return value at that node -- I hope everything else is self-explanatory.

The prediction function that I've written works, but it is way too slow. Basically it "walks down the tree", observation by observation:

predict.NT = function(tree.obj, newdata=NULL){
    if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
    tree = tree.obj$nodes
    if (length(tree)==1){#Return the mean for a stump
        return(rep(tree[[1]]$output, nrow(X)))  # one prediction per row
    }
    pred = apply(X, 1, godowntree, nn=1, tree=tree)  # use X so newdata=NULL works
    return(pred)
}

godowntree = function(x, tree, nn = 1){
    while (tree[[nn]]$Terminal == FALSE){
        fb = tree[[nn]]$children[1]
        sv = tree[[fb]]$splitvar
        sp = tree[[fb]]$splitpoint
        if (is.factor(sp)){
            if (as.character(x[names(x) == sv]) == sp){
                nn <- fb
            } else {
                nn <- fb + 1
            }
        } else {
            # apply() coerces rows to character, so convert back before comparing
            if (as.numeric(x[names(x) == sv]) < sp){
                nn <- fb
            } else {
                nn <- fb + 1
            }
        }
    }
    return(tree[[nn]]$output)
}

The problem is that it is really slow (when you consider that non-sample trees are bigger, and that I need to do this many, many times), even for a simple tree:

library(microbenchmark)
microbenchmark(predict.NT(testtree,sampx))
Unit: milliseconds
                        expr      min       lq     mean   median       uq
 predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274
     max neval
 40.4691   100

I got an idea from someone today: I could write a function-factory type of function (i.e. one that generates closures, which I'm just learning about) to compile my tree into a bunch of nested if/else statements. I could then send the data through that, which might be faster than extracting data from the tree over and over again. I haven't written the function-generating function yet, but I hand-wrote the sort of output I'd get from it, and tested that:

predictif = function(x){
    if (x[names(x) == 'bizrev_allHH'] < 788.875){
        if (x[names(x) == 'male_head'] <.872){
            return(548)
        } else {
            return(165)
        }
    } else {
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
            return(-283)
        }else{
            return(-11.4)
        }
    }
}
predictif.NT = function(tree.obj, newdata=NULL){
    if (is.null(newdata)){X = tree.obj$X} else {X = newdata}
    tree = tree.obj$nodes
    if (length(tree)==1){#Return the mean for a stump
        return(rep(tree[[1]]$output, nrow(X)))  # one prediction per row
    }
    pred = apply(X, 1, predictif)  # use X so newdata=NULL works
    return(pred)
}

microbenchmark(predictif.NT(testtree,sampx))
Unit: milliseconds
                          expr      min       lq     mean   median       uq
 predictif.NT(testtree, sampx) 12.77701 12.97551 14.21417 13.18939 13.67667
      max neval
 30.48373   100

A little faster, but not much!
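For what it's worth, here is a rough sketch of the generator I have in mind -- compiling the tree into nested closures rather than generating if/else source text. compile_tree is a made-up name; it assumes the node structure shown above and handles numeric splits only (a factor split would need an == branch, as in godowntree). Like predictif, the compiled function expects a single row:

```r
# Sketch: recursively compile each node into a closure. Split information
# lives on the first child (fb), matching godowntree above.
compile_tree <- function(tree) {
  build <- function(nn) {
    if (tree[[nn]]$Terminal) {
      out <- tree[[nn]]$output
      return(function(x) out)             # leaf: constant function
    }
    fb <- tree[[nn]]$children[1]          # split info is stored on the first child
    sv <- tree[[fb]]$splitvar
    sp <- tree[[fb]]$splitpoint
    left  <- build(fb)
    right <- build(fb + 1L)
    function(x) if (x[[sv]] < sp) left(x) else right(x)
  }
  build(1L)
}
```

Note that as written it indexes with x[[sv]], so rows should be passed as lists or one-row data frames rather than the character vectors produced by apply().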

I'd be really grateful for any ideas for speeding this up! Or, if the answer is "you really can't get this much faster without converting it to C/C++", that'd be valuable information as well (especially if you gave me some info about why that would be).

While I'd certainly appreciate answers in R, answers in pseudocode would also be really helpful.

Thanks!

asked Mar 09 '26 by generic_user
1 Answer

The secret to speeding up your function is vectorization. Instead of performing all your operations on each row individually, perform them on all rows at once.

Let's reconsider your predictif function:

predictif = function(x){
    if (x[names(x) == 'bizrev_allHH'] < 788.875){
        if (x[names(x) == 'male_head'] <.872){
            return(548)
        } else {
            return(165)
        }
    } else {
        if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){
            return(-283)
        }else{
            return(-11.4)
        }
    }
}

This approach is slow because it applies all of these operations to each row individually. Function calls, if statements, and especially operations like names(x) == 'bizrev_allHH' each carry a bit of overhead that adds up when repeated for every row.

In contrast, simply comparing two numbers is very fast! So instead, write a vectorized version of the above.

predictif_fast <- function(newdata) {
  n1 <- newdata$bizrev_allHH < 788.875
  n2 <- newdata$male_head < .872
  n3 <- newdata$nondurable_exp_mo < 4190.965

  ifelse(n1, ifelse(n2, 548.55893, 165.15537),
             ifelse(n3, -283.35145, -11.40185))
}

Notice, and this is very important, that this function is not being passed one instance. It is meant to be passed your entire newdata. This works because the < and ifelse operations are each vectorized: when given a vector, they return a vector.
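As a tiny illustration of what "vectorized" means here:

```r
x <- c(1, 10, 3)
x < 5                          # one call compares every element: TRUE FALSE TRUE
ifelse(x < 5, "low", "high")   # picks per element: "low" "high" "low"
```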

Let's compare your function and this new one:

> microbenchmark(predictif.NT(testtree, sampx),
                 predictif_fast(sampx))
Unit: microseconds
                          expr       min         lq     mean    median         uq
 predictif.NT(testtree, sampx) 12106.419 13144.2390 14684.46 13719.406 14593.1565
         predictif_fast(sampx)   189.093   213.6505   263.74   246.192   260.7895
       max neval cld
 79136.335   100   b
  2344.059   100  a 

Notice we got a 50x speedup through vectorization.

Incidentally, it is possible to speed this up substantially more (there are faster alternatives to ifelse if you get clever with indexing), but overall switching from "perform a function on each row" to "perform operations on entire vectors" gets you the largest speedup.
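For example, one indexing-based sketch (predictif_faster is a made-up name; the leaf values are copied from predictif_fast above) computes a leaf index per row with arithmetic on logicals and then indexes into the leaf values, avoiding ifelse entirely:

```r
predictif_faster <- function(newdata) {
  n1 <- newdata$bizrev_allHH < 788.875
  n2 <- newdata$male_head < .872
  n3 <- newdata$nondurable_exp_mo < 4190.965
  leaves <- c(548.55893, 165.15537, -283.35145, -11.40185)
  # leaf 1: n1 & n2, leaf 2: n1 & !n2, leaf 3: !n1 & n3, leaf 4: !n1 & !n3
  idx <- 1L + (n1 & !n2) + 2L * !n1 + (!n1 & !n3)
  leaves[idx]
}
```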


This doesn't solve your problem entirely, since you need to perform these vectorized operations on a general tree, not just this particular one. I won't solve the general version for you, but consider rewriting your godowntree function so that it takes the entire data frame and performs its operations on all rows at once. Then, instead of branching row by row, keep a vector recording which child node each row is currently in.
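A sketch of that idea, written as a recursive partition of row indices so that each internal node costs one vectorized comparison rather than one comparison per row (predict_vec is a made-up name; it assumes the node structure from the question and numeric splits only):

```r
predict_vec <- function(tree.obj, newdata) {
  tree <- tree.obj$nodes
  out <- numeric(nrow(newdata))
  recurse <- function(nn, rows) {
    if (tree[[nn]]$Terminal) {
      out[rows] <<- tree[[nn]]$output     # all rows landing here get this value
      return(invisible(NULL))
    }
    fb <- tree[[nn]]$children[1]          # split info is stored on the first child
    left <- newdata[rows, tree[[fb]]$splitvar] < tree[[fb]]$splitpoint
    recurse(fb, rows[left])               # rows satisfying the split go left
    recurse(fb + 1L, rows[!left])         # the rest go right
  }
  recurse(1L, seq_len(nrow(newdata)))
  out
}
```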

answered Mar 11 '26 by David Robinson

