I am running the stats::uniroot function on one million rows in a data.table. Here is a toy example:
library(data.table)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u
n <- 50000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)
print(system.time(
dt[, c := uniroot(froot, u=u, a=a, b=b, interval= c(0.01, 10), extendInt="yes")$root, by = u]
))
With the above code, the time taken is close to 8 seconds for 50,000 rows. Is there any faster alternative to the uniroot function which can reduce this time considerably?
160 seconds (1e6/5e4 * 8) doesn't sound so bad to me for a million rows (although maybe your real function is much slower than the froot you're using here?). This can be trivially parallelized, running separate chunks on separate cores (see e.g. answers to this question); a rough sketch with the base parallel package follows.
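For example, here is a minimal sketch of chunked parallelization with the base parallel package (the chunk count, the c_par column name, and the choice of mclapply are my own assumptions; mclapply forks, so on Windows you would use parLapply() instead):
library(parallel)
n_cores <- 4                                   ## adjust to your machine
## split the u values into one chunk per core, keeping the original row order
chunks <- split(dt$u, cut(seq_len(nrow(dt)), n_cores, labels = FALSE))
res <- mclapply(chunks, function(uvec) {
    vapply(uvec,
           function(ui) uniroot(froot, u = ui, a = a, b = b,
                                interval = c(0.01, 10), extendInt = "yes")$root,
           numeric(1))
}, mc.cores = n_cores)
dt[, c_par := unlist(res, use.names = FALSE)]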
How badly do you need extendInt? I can triple the speed by making a hacked version of the uniroot() function with only its core functionality and none of the argument-testing logic. However, your speed gain will be much less impressive if your target function is much slower than the example you've given here; if that's the case, you should focus on speeding up your target function. (I tried recoding your froot in C++ via Rcpp, but it doesn't really help in this case: the function is sufficiently trivial that the function-calling overhead takes most of the time.)
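The Rcpp version isn't shown here; a minimal sketch of what it might look like, assuming the froot_cpp used in the benchmark below was defined along these lines:
library(Rcpp)
## hypothetical reconstruction of the C++ version of froot
cppFunction('
double froot_cpp(double x, double u, double a, double b) {
    // cumulative hazard b * (x/b)^a, minus u
    return b * std::pow(x / b, a) - u;
}')
froot_cpp(2.5, 1.5, 0.5, 1)  ## same value as froot(2.5, 1.5, 0.5, 1)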
I did this with only 5000 rows, for ease of benchmarking:
n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)
Minimal function:
uu <- function(f, lower, upper, tol = 1e-8, maxiter = 1000L, ...) {
    f.lower <- f(lower, ...)
    f.upper <- f(upper, ...)
    val <- .External2(stats:::C_zeroin2, function(arg) f(arg, ...),
                      lower, upper, f.lower, f.upper, tol, as.integer(maxiter))
    return(val[1])
}
Check that we get the same results:
identical(uniroot(froot, u = 3.242, a = 0.5, b = 1, interval = c(0.01, 100))$root,
          uu(froot, u = 3.242, a = 0.5, b = 1, lower = 0.01, upper = 100))
## TRUE
Load a benchmarking package and wrap the evaluations in functions for compactness:
library(rbenchmark)
f1 <- function() {
    dt[, c := uniroot(froot_cpp, u = u, a = a, b = b, interval = c(0.01, 10), extendInt = "yes")$root, by = u]
}
f2 <- function() {
    dt[, c := uu(froot, u = u, a = a, b = b, lower = 0.01, upper = 100), by = u]
}
bb <- benchmark(f1(), f2(),
                columns = c("test", "replications", "elapsed", "relative"))
Results:
test replications elapsed relative
1 f1() 100 34.616 3.074
2 f2() 100 11.261 1.000
Note that the inverse of the function shown can be computed explicitly as
f2 <- function(x) (b^a * x / b)^(1/a)
a <- 1/2
b <- 1
all.equal(f(.5), f2(.5)) # f defined below using uniroot
## [1] TRUE
However, assuming that in reality you have a more complex function, we can use Chebyshev approximation to get a close approximation to it. Note that a and b are constants in the question, so we also assume that to be the case below, i.e. f uses the constants a and b set in the global environment. The code below runs nearly 100x faster than the code in the question on this benchmark with a 9th-degree polynomial and is within 1e-4 of the answer given by uniroot. Use a higher degree if you need even more accuracy.
library(data.table)
library(pracma)
set.seed(123)
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u
n <- 5000
u <- -log(runif(n))
a <- 1/2
b <- 1
dt = data.table(u = u, a = a, b = b)
dt2 <- copy(dt)
f <- function(u) {
    uniroot(froot, u = u, a = a, b = b, interval = c(0.01, 10), extendInt = "yes")$root
}
library(microbenchmark)
microbenchmark(times = 10,
    orig = dt[, c := uniroot(froot, u = u, a = a, b = b, interval = c(0.01, 10), extendInt = "yes")$root, by = u],
    cheb = dt2[, c := chebApprox(u, Vectorize(f), min(u), max(u), 9)]
)
## Unit: milliseconds
## expr min lq mean median uq max neval cld
## orig 943.5323 948.9321 961.00361 958.91970 972.6308 982.0060 10 b
## cheb 9.3752 9.7513 10.67386 10.02555 10.3411 16.9475 10 a
max(abs(dt$c - dt2$c))
## [1] 8.081021e-05
There are great answers to the exact question above, but here are a couple of notes on general R practices.
In the OP, we are using by = u so that each row is run one at a time. This is inefficient! data.table will order u, determine the groupings, and, since the values are random real numbers, end up with as many groups as rows.
Instead, we can use Map() or mapply() to iterate through the rows, which will improve performance. Note that it is unclear whether a and b actually vary by row; if they truly are constant, we would likely want to take them out of the data.table and pass them as constants (see the sketch after the benchmark output below).
uniroot2 = function(...) uniroot(...)$root ## helper function

dt[, c2 := mapply(uniroot2, u, a, b,
                  MoreArgs = list(f = froot,
                                  interval = c(0.01, 10),
                                  extendInt = 'yes'))]
## for n = 5000
## # A tibble: 2 x 13
## expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
## <bch:expr> <bch:t> <bch:t> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm>
## 1 OP 1.17s 1.17s 0.851 170KB 2.55 1 3 1.17s
## 2 no_by 857.2ms 857.2ms 1.17 214KB 3.50 1 3 857.2ms
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled.
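As mentioned above, if a and b really are constants, here is a hedged sketch of pulling them out of the data.table and passing them once via MoreArgs (a_const, b_const, and c2b are names I made up for illustration):
a_const <- 1/2   ## assumed constants, matching the question
b_const <- 1
dt[, c2b := mapply(uniroot2, u,
                   MoreArgs = list(f = froot, a = a_const, b = b_const,
                                   interval = c(0.01, 10),
                                   extendInt = 'yes'))]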
Note that once we have it set up with mapply(), it is trivial to use future.apply::future_mapply() to parallelize the call. This is 2.5 times faster than the no_by example above on my laptop.
library(future.apply)
plan(multisession)
dt[, c3 := future_mapply(uniroot2, u, a, b,
                         MoreArgs = list(f = froot,
                                         interval = c(0.01, 10),
                                         extendInt = 'yes'),
                         future.globals = "cumhaz")] ## see next section for how we could remove this
In your example, you define two functions as:
cumhaz <- function(t, a, b) b * (t/b)^a
froot <- function(x, u, a, b) cumhaz(x, a, b) - u
When performance is an issue and it is trivial to simplify, you may want to do so.
froot2 = function(x, u, a, b) b * (x / b) ^ a - u
Over a million loops, the additional call to cumhaz() adds up:
x = 2.5; u = 1.5; a = 0.5; b = 1
bench::mark(froot_rep  = for (i in 1:1e6) {froot(x = x, u = u, a = a, b = b)},
            froot2_rep = for (i in 1:1e6) {froot2(x = x, u = u, a = a, b = b)})
## # A tibble: 2 x 13
## expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
## <bch:expr> <bch:t> <bch:t> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm>
## 1 froot_rep 4.74s 4.74s 0.211 13.8KB 3.38 1 16 4.74s
## 2 froot2_rep 3.17s 3.17s 0.315 13.8KB 2.84 1 9 3.17s
##
## Warning message:
## Some expressions had a GC in every iteration; so filtering is disabled.
And uniroot would further multiply the number of calls, with a default maximum of 1,000 iterations! That means cumhaz() could cost us somewhere between 1.5s and 1,500s during the optimization. Finally, as @G. Grothendieck pointed out, sometimes we can solve the problem directly and use vectorized methods instead of relying on uniroot or optimize; a sketch of that is below.
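To make that last point concrete, here is a hedged sketch of the fully vectorized direct solution for this particular cumhaz, equivalent to the explicit inverse shown earlier (c4 is a made-up column name):
## no root finding at all: invert b * (x/b)^a = u analytically
dt[, c4 := b * (u / b)^(1/a)]
all.equal(dt$c2, dt$c4, tolerance = 1e-3)  ## agrees with the uniroot results up to their tolerance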