
Is there any faster alternative to the stats::uniroot function in R?

Tags:

r

data.table

I am running the stats::uniroot function on one million rows in a data.table. Here is a toy example:

library(data.table)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 50000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

print(system.time(
dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
))

The code above takes close to 8 seconds for 50,000 rows.

Is there any faster alternative to the uniroot function which can reduce this time considerably?

Saurabh asked Nov 06 '21 00:11


3 Answers

160 seconds (1e6/5e4 * 8) doesn't sound so bad to me for a million rows (although maybe your real function is much slower than the froot you're using here?). This can be trivially parallelized, running separate chunks on separate cores (see e.g. answers to this question).
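As a rough sketch of the chunked approach (not necessarily what the linked answers do), the base parallel package can split the u values into one chunk per worker. The helper name solve_chunk, the worker count of 2, and the fixed bracket c(0, 1000) used in place of extendInt are assumptions made for illustration.

```r
library(parallel)

cumhaz <- function(t, a, b) b * (t / b)^a
froot  <- function(x, u, a, b) cumhaz(x, a, b) - u

set.seed(1)
u <- -log(runif(2000))
a <- 1/2
b <- 1

# Hypothetical helper: solve one chunk of u values serially.
# Assumption: every root lies in [0, 1000], so no interval extension is needed.
solve_chunk <- function(u_chunk, a, b) {
  vapply(u_chunk, function(ui) {
    uniroot(froot, u = ui, a = a, b = b, interval = c(0, 1000))$root
  }, numeric(1))
}

n_workers <- 2L                      # assumption: adjust to the cores you have
cl <- makeCluster(n_workers)
# Copy the functions and data to each worker.
clusterExport(cl, c("cumhaz", "froot", "solve_chunk", "u", "a", "b"),
              envir = environment())
# Contiguous index chunks, so unlist() restores the original row order.
chunks <- splitIndices(length(u), n_workers)
roots_list <- parLapply(cl, chunks, function(idx) solve_chunk(u[idx], a, b))
stopCluster(cl)

roots <- unlist(roots_list)
```

Starting the workers and copying data has a fixed cost, so this only pays off when each chunk does a substantial amount of work.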

How badly do you need extendInt? I can triple the speed with a hacked version of the uniroot() function that keeps only its core functionality and drops the argument-testing logic and so on. However, your speed gain will be much less impressive if your target function is much slower than the example you've given here; if that's the case, you should focus on speeding up your target function. (I tried recoding your froot in C++ via Rcpp, but it doesn't really help in this case: the function is sufficiently trivial that the function-calling overhead takes most of the time.)

I did this with only 5000 rows, for ease of benchmarking:

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

Minimal function:

uu <- function(f, lower, upper, tol = 1e-8, maxiter =1000L, ...) {
  f.lower <- f(lower, ...)
  f.upper <- f(upper, ...)
  val <- .External2(stats:::C_zeroin2, function(arg) f(arg, ...),
                    lower, upper, f.lower, f.upper, tol, as.integer(maxiter))
  return(val[1])
}

Check that we get the same results:

identical(uniroot(froot, u = 3.242, a=0.5, b=1, interval = c(0.01,100))$root,
          uu(froot, u = 3.242, a=0.5, b=1, lower = 0.01, upper = 100))
## TRUE
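One caveat: uu() has no extendInt and no sign check of its own, so each row's root has to lie inside [lower, upper] for the result to be meaningful. Because froot is vectorized in u, that can be checked for all rows at once before solving. The following is only a sketch; the variable names and the wider bracket c(0, 1000) are illustrative assumptions.

```r
cumhaz <- function(t, a, b) b * (t / b)^a
froot  <- function(x, u, a, b) cumhaz(x, a, b) - u

set.seed(1)
u <- -log(runif(5000))
a <- 1/2
b <- 1

# A root is bracketed when f changes sign (or is zero) across the interval.
is_bracketed <- function(lower, upper) {
  froot(lower, u, a, b) * froot(upper, u, a, b) <= 0
}

ok_narrow <- is_bracketed(0.01, 100)   # bracket used with uu() above
ok_wide   <- is_bracketed(0, 1000)     # assumed wider bracket

mean(ok_narrow)   # share of rows whose root lies in [0.01, 100]
mean(ok_wide)     # share of rows whose root lies in [0, 1000]
```

Rows with u below cumhaz(0.01, a, b) have their root below 0.01, so with this toy data a lower bound of 0 is the safer choice when extendInt is not available.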

Load a benchmarking package and wrap the evaluations in functions for compactness:

library(rbenchmark)
f1 <- function() {
  ## froot_cpp is the Rcpp version of froot mentioned above (its code is not shown here)
  dt[, c := uniroot(froot_cpp, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
}
f2 <- function() {
  dt[, c := uu(froot, u=u, a=a, b=b, lower = 0.01, upper = 100), by = u]
}
bb <- benchmark(f1(), f2(), 
    columns =c("test", "replications", "elapsed", "relative"))

Results:

  test replications elapsed relative
1 f1()          100  34.616    3.074
2 f2()          100  11.261    1.000
Ben Bolker answered Oct 21 '22 07:10


Note that the inverse of the function shown can be computed explicitly as

f2 <- function(x) (b^a * x / b)^(1/a)
a <- 1/2
b <- 1
all.equal(f(.5), f2(.5))  # f defined below using uniroot
## [1] TRUE

However, assuming that in reality you have a more complex function, we can use Chebyshev approximation to get a close approximation to its inverse. Note that a and b are constants in the question, so we assume the same below, i.e. f uses the constants a and b set in the global environment. On the benchmark, the code below runs nearly 100x faster than the code in the question with a 9th degree polynomial and is within 1e-4 of the answer given by uniroot. Use a higher degree if you need even more accuracy.

library(data.table)
library(pracma)
set.seed(123)

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)

dt2 <- copy(dt)
f <- function(u) {
  uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root
}

library(microbenchmark)
microbenchmark(times = 10,
  orig = dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u],
  cheb = dt2[, c := chebApprox(u, Vectorize(f), min(u), max(u), 9)]
)
## Unit: milliseconds
##  expr      min       lq      mean    median       uq      max neval cld
##  orig 943.5323 948.9321 961.00361 958.91970 972.6308 982.0060    10   b
##  cheb   9.3752   9.7513  10.67386  10.02555  10.3411  16.9475    10  a 

max(abs(dt$c - dt2$c))
## [1] 8.081021e-05
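The speedup presumably comes from calling uniroot only at a handful of nodes and evaluating a cheap interpolant everywhere else. The same idea can be sketched in base R with a cubic spline swapped in for the Chebyshev polynomial. The node count of 50, the bracket c(0, 1000) and the names below are assumptions for illustration, not part of the benchmark above.

```r
cumhaz <- function(t, a, b) b * (t / b)^a
froot  <- function(x, u, a, b) cumhaz(x, a, b) - u

set.seed(123)
u <- -log(runif(5000))
a <- 1/2
b <- 1

# Solve exactly (to a tight tolerance) at a small grid of u values only.
nodes <- seq(min(u), max(u), length.out = 50)
node_roots <- vapply(nodes, function(ui) {
  uniroot(froot, u = ui, a = a, b = b,
          interval = c(0, 1000), tol = 1e-10)$root
}, numeric(1))

# Interpolate the inverse between the nodes (cubic spline, not Chebyshev).
inv_approx <- splinefun(nodes, node_roots)
approx_roots <- inv_approx(u)   # one vectorized evaluation for all rows
```

Only 50 root-finding calls are made regardless of the number of rows; how many nodes are needed for a given accuracy depends on how smooth the real inverse is.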
G. Grothendieck answered Oct 21 '22 08:10


There are great answers to the exact question already, so here are a couple of notes on general R practices.

Using by when order doesn't matter

In the OP, we are using by = u so that each row is run one at a time. This is inefficient! data.table will order u, determine groupings, and, since the values are random real numbers, end up with as many groupings as rows.

Instead, we can use Map() or mapply() to iterate through the rows, which will improve performance. Note that it's unclear whether a and b actually vary by row; if they truly are constant, we would likely want to take them out of the data.table and pass them as constants.

uniroot2 = function(...) uniroot(...)$root ## helper function
dt[, c2 := mapply(uniroot2, u, a, b,
                  MoreArgs = list(f = froot,
                                  interval = c(0.01, 10),
                                  extendInt = 'yes'))]

## for n = 5000

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
##  1 OP           1.17s   1.17s     0.851     170KB     2.55     1     3      1.17s
##  2 no_by      857.2ms 857.2ms     1.17      214KB     3.50     1     3    857.2ms
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

Note, once we have it set up in mapply, it is trivial to use future.apply::future_mapply() to parallelize our call. This is 2.5 times faster than the no_by example above on my laptop.

library(future.apply)
plan(multisession)
dt[, c3 := future_mapply(uniroot2, u, a, b,
                         MoreArgs = list(f = froot,
                                         interval = c(0.01, 10),
                                         extendInt = 'yes'),
                         future.globals = "cumhaz")] ## see next section for how we could remove this

Function calls take time

In your example, you define two functions as:

cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u

When performance is an issue and it is trivial to simplify, you may want to simplify.

froot2 = function(x, u, a, b) b * (x / b) ^ a - u

Over a million loops, the additional call to cumhaz() adds up:

x = 2.5; u = 1.5; a = 0.5; b = 1 
bench::mark(froot_rep = for (i in 1:1e6) {froot(x=x, u=u, a=a, b=b)},
            froot2_rep = for (i in 1:1e6) {froot2(x=x, u=u, a=a, b=b)})

## # A tibble: 2 x 13
##   expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
##   <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
## 1 froot_rep    4.74s   4.74s     0.211    13.8KB     3.38     1    16      4.74s
## 2 froot2_rep   3.17s   3.17s     0.315    13.8KB     2.84     1     9      3.17s
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled. 

uniroot multiplies the number of calls further, since it evaluates the function repeatedly for every row, up to its default maximum of 1,000 iterations. That means cumhaz() costs us somewhere between 1.5s and 1,500s during the optimization. And as @G. Grothendieck pointed out, sometimes we can solve the equation directly and use vectorized methods instead of relying on uniroot or optimize.
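For the toy cumhaz in the question, solving b * (t/b)^a = u for t gives t = b * (u/b)^(1/a), which can be applied to the whole vector in one call. A minimal sketch, assuming the real function has the same closed form (the helper name inv_cumhaz is made up here):

```r
cumhaz <- function(t, a, b) b * (t / b)^a
froot  <- function(x, u, a, b) cumhaz(x, a, b) - u

# Hypothetical helper: closed-form inverse of cumhaz in t.
inv_cumhaz <- function(u, a, b) b * (u / b)^(1 / a)

set.seed(1)
u <- -log(runif(1e6))
a <- 1/2
b <- 1

roots <- inv_cumhaz(u, a, b)   # one vectorized call for a million rows

# Spot check against uniroot on a few values whose roots lie in [0.01, 10].
u_check <- c(0.5, 1, 2, 3)
by_uniroot <- vapply(u_check, function(ui) {
  uniroot(froot, u = ui, a = a, b = b,
          interval = c(0.01, 10), tol = 1e-10)$root
}, numeric(1))
max(abs(inv_cumhaz(u_check, a, b) - by_uniroot))
```

In a data.table this would be a single assignment such as dt[, c := inv_cumhaz(u, a, b)], with no by and no root finder at all.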

Cole answered Oct 21 '22 07:10