
Sum of most recent values across groups

Tags:

r

data.table

For each row of my data, I'd like to compute the sum of the most recent value of each group (as of that row):

library(data.table)
dt = data.table(group = c('a','b','a','a','b','a'),
                value = c(10, 5, 20, 15, 15, 10),
                desired = c(10, 15, 25, 20, 30, 25))
#   group value desired
#1:     a    10      10
#2:     b     5      15
#3:     a    20      25  # latest value of a is 20, of b is 5
#4:     a    15      20  # latest value of a is 15, of b is 5
#5:     b    15      30
#6:     a    10      25

The desired column is what I want to achieve. I can do this with a naive loop, but my data is quite large, with a lot of rows and groups (1M+ rows, 1000+ groups).

for (i in seq_len(nrow(dt))) {
  # can use `set` to make this faster, but still too slow
  # this is just to illustrate *a* solution
  dt[i, desired1 := dt[1:i, value[.N], by = group][, sum(V1)]]
}
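Not part of the original question, but it helps to see that only one pass is needed: keep a running total plus the latest value seen for each group, and when a row for group g arrives, change the total by (new value minus g's previous latest value). A minimal base R sketch, assuming group is a character vector and value is numeric; latest_sum_loop is a hypothetical helper name. It is still an interpreted R loop, so it shows the idea rather than being fast on 1M+ rows.

```r
# Hypothetical helper: for each row, the sum of the most recent value of every group so far.
latest_sum_loop <- function(group, value) {
  latest <- numeric(0)              # named vector: latest value per group seen so far
  total  <- 0                       # running sum of the entries in `latest`
  out    <- numeric(length(value))
  for (i in seq_along(value)) {
    g    <- group[i]
    prev <- if (g %in% names(latest)) latest[[g]] else 0
    total <- total + value[i] - prev  # swap g's old latest value for the new one
    latest[g] <- value[i]
    out[i] <- total
  }
  out
}

latest_sum_loop(c('a','b','a','a','b','a'), c(10, 5, 20, 15, 15, 10))
# [1] 10 15 25 20 30 25
```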
eddi asked Jun 09 '16 17:06


1 Answer

Even simpler logic from @eddi (in the comments), replacing the roundabout approach shown below:

dt[, incr := diff(c(0, value)), by = group][, ans := cumsum(incr)] 
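Why this works: within a group, diff(c(0, value)) is how much that group's latest value changes at each of its rows (the group's first row contributes its value itself). Summing these increments over all rows in order therefore telescopes to the current sum of every group's latest value. A base R sketch of the same idea, using ave() in place of data.table's by = and assuming the rows are already in time order:

```r
group <- c('a','b','a','a','b','a')
value <- c(10, 5, 20, 15, 15, 10)

# Per-group increment: change in that group's latest value at each row.
incr <- ave(value, group, FUN = function(v) diff(c(0, v)))
incr
# [1] 10  5 10 -5 10 -5

# Running total of all increments = sum of the latest value of every group.
ans <- cumsum(incr)
ans
# [1] 10 15 25 20 30 25
```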

Not sure how it extends to more groups, but here it is on example data with 3 groups:

# I hope I got the desired output correctly
require(data.table)
dt = data.table(group = c('a','b','c','a','a','b','c','a'),
                value = c(10, 5, 20, 25, 15, 15, 30, 10),
                desired = c(10, 15, 35, 50, 40, 50, 60, 55))

Add an rleid:

dt[, id := rleid(group)] 
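rleid() numbers runs of consecutive identical values, so adjacent rows of the same group share an id. For the 3-group example above, a base R equivalent built on rle() (shown only to illustrate what the ids look like; rows 4 and 5 are both a and get the same id):

```r
group <- c('a','b','c','a','a','b','c','a')

# Run-length encode, then repeat each run's index over its length.
r  <- rle(group)
id <- rep(seq_along(r$lengths), r$lengths)
id
# [1] 1 2 3 4 4 5 6 7
```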

Extract the last value of each run, i.e. for each (group, id):

last = dt[, .(value=value[.N]), by=.(group, id)] 

The last table has one row per id. The idea now is to compute the increment for each id, and then join + update back.

last = last[, incr := value - shift(value, type="lag", fill=0L), by=group
          ][, incr := cumsum(incr)-value][]

Join + update now:

dt[last, ans := value + i.incr, on="id"][, id := NULL][]
#    group value desired ans
# 1:     a    10      10  10
# 2:     b     5      15  15
# 3:     c    20      35  35
# 4:     a    25      50  50
# 5:     a    15      40  40
# 6:     b    15      50  50
# 7:     c    30      60  60
# 8:     a    10      55  55
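To see the arithmetic behind these steps without data.table's join syntax, here is a base R sketch that mirrors (but does not reproduce) the code above: run ids, last value per run, per-group increments, then an offset equal to the sum of the other groups' latest values, which every row in the run adds to its own value. It assumes value is numeric and rows are in time order.

```r
group <- c('a','b','c','a','a','b','c','a')
value <- c(10, 5, 20, 25, 15, 15, 30, 10)

# 1. Run ids, like rleid(group).
r  <- rle(group)
id <- rep(seq_along(r$lengths), r$lengths)

# 2. Last value of each run; r$values holds each run's group.
last_value <- as.numeric(tapply(value, id, function(v) v[length(v)]))

# 3. Increment of each run within its group (lag with fill 0), like shift(..., fill = 0).
incr <- ave(last_value, r$values, FUN = function(v) v - c(0, head(v, -1)))

# Cumulative total minus the run's own value = sum of the other groups' latest values.
offset <- cumsum(incr) - last_value

# "Join" back: every row in a run gets its own value plus that run's offset.
ans <- value + offset[id]
ans
# [1] 10 15 35 50 40 50 60 55
```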

I'm not yet sure where (or if) this breaks; I'll look at it carefully now. I posted it right away so that more eyes are on it.
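One way to hunt for cases where a candidate breaks is a randomized comparison against a brute-force definition (for each row, take the last value of every group seen so far and sum them). A base R sketch; brute_force and fast are hypothetical helper names, and fast uses the diff/cumsum logic from the top of this answer. Any other candidate can be swapped in for fast.

```r
# Brute force: for row i, sum the most recent value of each group among rows 1..i.
brute_force <- function(group, value) {
  vapply(seq_along(value), function(i) {
    sum(tapply(value[1:i], group[1:i], function(v) v[length(v)]))
  }, numeric(1))
}

# Per-group increments, then a running total (same idea as the one-liner).
fast <- function(group, value) {
  cumsum(ave(value, group, FUN = function(v) diff(c(0, v))))
}

set.seed(1)
ok <- vapply(1:200, function(k) {
  n     <- sample(1:30, 1)
  group <- sample(letters[1:5], n, replace = TRUE)   # character groups, random order
  value <- as.numeric(sample(-50:50, n, replace = TRUE))
  isTRUE(all.equal(brute_force(group, value), fast(group, value)))
}, logical(1))
all(ok)
# [1] TRUE
```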


Comparison with David's solution on 10,000 rows and 500 groups:

require(data.table)
set.seed(45L)
groups = apply(matrix(sample(letters, 500L*10L, TRUE), ncol=10L), 1L, paste, collapse="")
uniqueN(groups) # 500L
N = 1e4L
dt = data.table(group=sample(groups, N, TRUE), value = sample(100L, N, TRUE))

arun <- function(dt) {
    dt[, id := rleid(group)]
    last = dt[, .(value=value[.N]), by=.(group, id)]
    last = last[, incr := value - shift(value, type="lag", fill=0L), by=group
              ][, incr := cumsum(incr)-value][]
    dt[last, ans := value + i.incr, on="id"][, id := NULL][]
    dt$ans
}

david <- function(dt) {
    dt[, indx := .I]
    res <- dcast(dt, indx ~ group)
    for (j in names(res)[-1L])
        set(res, j = j, value = res[!is.na(res[[j]])][res, on = "indx", roll = TRUE][[j]])
    rowSums(as.matrix(res)[, -1], na.rm = TRUE)
}

system.time(ans1 <- arun(dt))  ## 0.024s
system.time(ans2 <- david(dt)) ## 38.97s

identical(ans1, as.integer(ans2))
# [1] TRUE
Arun answered Oct 11 '22 05:10