
data.table | faster row-wise recursive update within group

I have to do the following recursive row-by-row operation to obtain z:

myfun = function (xb, a, b) {

z = NULL

for (t in 1:length(xb)) {

    if (t >= 2) { a[t] = b[t-1] + xb[t] }
    z[t] = rnorm(1, mean = a[t])
    b[t] = a[t] + z[t]

}

return(z)

}

library(data.table)

set.seed(1)

n_smpl = 1e6 
ni = 5

id = rep(1:n_smpl, each = ni)

smpl = data.table(id)
smpl[, time := 1:.N, by = id]

a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

smpl[, z := myfun(xb, a, b), by = id]

I would like to obtain a result like this:

      id time a b  xb            z
  1:   1    1 1 1   1    0.3735462
  2:   1    2 1 1   2    2.7470924
  3:   1    3 1 1   3    8.4941848
  4:   1    4 1 1   4   20.9883695
  5:   1    5 1 1   5   46.9767390
 ---                              
496: 100    1 1 1 100    0.3735462
497: 100    2 1 1 200  200.7470924
498: 100    3 1 1 300  701.4941848
499: 100    4 1 1 400 1802.9883695
500: 100    5 1 1 500 4105.9767390

This does work but takes time:

system.time(smpl[, z := myfun(xb, a, b), by = id])
   user  system elapsed 
 33.646   0.994  34.473

I need to make it faster, given the size of my actual data (over 2 million observations). I guess that do.call(myfun, .SD) with .SDcols = c('xb', 'a', 'b') and by = .(id, time) would be much faster, since it avoids the for loop inside myfun. However, I am not sure how to update b and its lag (probably using shift) when running this row-by-row operation in data.table. Any suggestions?

jayc asked Jan 27 '17 03:01


1 Answer

Great question!

Starting from a fresh R session, here is the demo data (5 million rows), your function from the question with some comments inline, and the timing on my laptop.

require(data.table)   # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

myfun = function (xb, a, b) {

  z = NULL
  # initializes a new length-0 variable

  for (t in 1:length(xb)) {

      if (t >= 2) { a[t] = b[t-1] + xb[t] }
      # if() on every iteration. t==1 could be done before loop

      z[t] = rnorm(1, mean = a[t])
      # z vector is grown by 1 item, each time

      b[t] = a[t] + z[t]
      # stores every b[t] in the vector, when only the previous
      # value b[t-1] is needed on the next iteration
  }
  return(z)
}

set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
   user  system elapsed 
 19.216   0.004  19.212

smpl
              id time a b      xb            z
      1:       1    1 1 1       1 3.735462e-01
      2:       1    2 1 1       2 3.557190e+00
      3:       1    3 1 1       3 9.095107e+00
      4:       1    4 1 1       4 2.462112e+01
      5:       1    5 1 1       5 5.297647e+01
     ---                                      
4999996: 1000000    1 1 1 1000000 1.618913e+00
4999997: 1000000    2 1 1 2000000 2.000000e+06
4999998: 1000000    3 1 1 3000000 7.000003e+06
4999999: 1000000    4 1 1 4000000 1.800001e+07
5000000: 1000000    5 1 1 5000000 4.100001e+07

So 19.2s is the time to beat. In all these timings, I've run the command 3 times locally to make sure it's a stable timing. The timing variance is insignificant in this task so I'll just report one timing to keep the answer quicker to read.
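The inline comment about growing z can be looked at in isolation. Below is a minimal sketch in base R; grow() and prealloc() are hypothetical names used only for illustration, and t^2 stands in for the real per-row work. Both return the same values, so any difference in timing comes from how the result vector is allocated. How large that difference is depends on the vector length, the machine and the R version, so no timings are quoted here.

```r
# Hypothetical helpers, for illustration only (not part of the question).

# Starts from NULL, so the vector has to be extended on every assignment.
grow = function(n) {
  z = NULL
  for (t in seq_len(n)) z[t] = t^2
  z
}

# Allocates the full length once, then fills it in place.
prealloc = function(n) {
  z = numeric(n)
  for (t in seq_len(n)) z[t] = t^2
  z
}

same_result = identical(grow(1000), prealloc(1000))

# To compare speed on your own machine:
# system.time(for (i in 1:1e4) grow(1000))
# system.time(for (i in 1:1e4) prealloc(1000))
```

With groups of only 5 rows, as in the demo data, the gain from this change alone is likely to be modest.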

Tackling the comments inline above in myfun() :

myfun2 = function (xb, a, b) {

  z = numeric(length(xb))
  # allocate size up front rather than growing

  z[1] = rnorm(1, mean=a[1])
  prevb = a[1]+z[1]
  t = 2L
  while(t<=length(xb)) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean=at)
    prevb = at + z[t]
    t = t+1L
  }
  return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
   user  system elapsed 
 13.212   0.036  13.245 
smpl[,identical(z,z2)]
[1] TRUE

That was quite good (19.2s down to 13.2s) but it's still a loop at R level. At first glance it can't be vectorized because each rnorm() call depends on the previous value. In fact, it probably can be vectorized by using the property that m + sd*rnorm(n, mean=0, sd=1) has the same distribution as rnorm(n, mean=m, sd=sd), and calling vectorized rnorm(n=5e6) once rather than calling rnorm(1) 5e6 times. But there'd probably be a cumsum() involved to deal with the groups. So let's not go there, as that would probably make the code harder to read and would be specific to this precise problem.
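For the curious, here is roughly what that vectorization could look like for a single group. This is a sketch under stated assumptions, not something that was timed for this answer. myfun_loop() and myfun_vec() are hypothetical names, and a and b are assumed to start as the scalar a1, as in the demo data. Writing z[t] = a[t] + e[t] with e standard normal, the recursion reduces to a[t] = 2*a[t-1] + e[t-1] + xb[t], and dividing through by 2^t turns that into a cumulative sum. Because of the 2^t factor this only suits short groups; 2^t overflows a double after roughly a thousand rows.

```r
# Reference: the same recursion as the question's loop, for one group.
myfun_loop = function(xb, a1 = 1) {
  n = length(xb); z = numeric(n); a = a1
  for (t in seq_len(n)) {
    if (t >= 2) a = b + xb[t]
    z[t] = rnorm(1, mean = a)
    b = a + z[t]
  }
  z
}

# Hypothetical vectorized version for one short group.
myfun_vec = function(xb, a1 = 1) {
  n = length(xb)
  e = rnorm(n)                       # all the noise for this group in one call
  t = seq_len(n)
  u = c(0, e[-n] + xb[-1L])          # u[t] = e[t-1] + xb[t] for t >= 2, u[1] = 0
  a = 2^t * (a1 / 2 + cumsum(u / 2^t))
  a + e                              # z[t] = a[t] + e[t]
}

set.seed(1); z_loop = myfun_loop(1:5)
set.seed(1); z_vec  = myfun_vec(1:5)
all.equal(z_loop, z_vec)             # compared up to floating-point rounding
```

The comparison uses all.equal() rather than identical() because the arithmetic is done in a different order. Applying this to the whole column in one call would still need a grouped cumsum, which is the complication mentioned above.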

So let's try Rcpp which looks very similar to the style you wrote and is more widely applicable :

require(Rcpp)   # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
  NumericVector z = NumericVector(xb.length());
  z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
  double prevb = a[0]+z[0];
  int t = 1;
  while (t<xb.length()) {
    double at = prevb + xb[t];
    z[t] = R::rnorm(at, 1);
    prevb = at + z[t];
    t++;
  }
  return z;
}')

set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
   user  system elapsed 
  1.800   0.020   1.819 
smpl[,identical(z,z3)]
[1] TRUE

Much better: 19.2s down to 1.8s. But every call to the function runs its first line (NumericVector()), which allocates a new vector as long as the number of rows in the group. That vector is filled in and returned, then copied by := into the correct place in the final column for that group, and finally released. Allocating and managing 1 million small temporary vectors (one for each group) is all a bit convoluted.

Why don't we do the whole column in one go? You've already written it in a for loop style and there's nothing wrong with that. Let's tweak the C++ function to accept the id column too and add an if for when it reaches a new group.

cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {

  // ** id must be pre-grouped, such as via setkey(DT,id) **

  NumericVector z = NumericVector(id.length());
  int previd = id[0]-1;  // initialize to anything different than id[0]
  double prevb = 0.0;    // declared outside the loop so it persists across rows
  for (int i=0; i<id.length(); i++) {
    if (id[i]!=previd) {
      // first row of new group
      z[i] = R::rnorm(a[i], 1);
      prevb = a[i]+z[i];
      previd = id[i];
    } else {
      // 2nd row of group onwards
      double at = prevb + xb[i];
      z[i] = R::rnorm(at, 1);
      prevb = at + z[i];
    }
  }
  return z;
}')

system.time(setkey(smpl,id))  # ensure grouped by id
   user  system elapsed
  0.028   0.004   0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
   user  system elapsed 
  0.232   0.004   0.237 
smpl[,identical(z,z4)]
[1] TRUE

That's better: 19.2s down to 0.27s (0.033s for setkey() plus 0.237s for the single pass over the column).
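Since myfun4() silently assumes that the rows of each id are contiguous, it may be worth checking that before calling it on data that has not just been keyed. A minimal sketch in base R follows; is_contiguous() is a hypothetical helper name, not part of data.table or Rcpp.

```r
# Hypothetical helper: TRUE when every id occupies a single contiguous run
# of rows, which is what the single-pass function assumes.
is_contiguous = function(id) {
  runs = rle(id)$values         # one entry per run of equal ids
  anyDuplicated(runs) == 0L     # an id that starts a second run is split up
}

ok_grouped = is_contiguous(c(1L, 1L, 2L, 2L, 3L))   # each id in one block
ok_split   = is_contiguous(c(1L, 2L, 1L, 2L, 3L))   # ids 1 and 2 are split
```

A call such as stopifnot(is_contiguous(smpl$id)) before the myfun4() step would then fail early instead of producing wrong values for split groups.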

Matt Dowle answered Nov 04 '22 10:11