R data.table sum of group subset using dates

Tags: r, data.table

I have a dataset like the following:

library(data.table)    
dt1 <- data.table(urn = c(rep("a", 5), rep("b", 4)),
                  amount = c(10, 12, 23, 15, 19, 42, 11, 5, 10),
                  date = as.Date(c("2016-01-01", "2017-01-02", "2017-02-04",
                                   "2017-04-19", "2018-02-11", "2016-02-14",
                                   "2017-05-06", "2017-05-12", "2017-12-12")))
dt1
#    urn amount       date
# 1:   a     10 2016-01-01
# 2:   a     12 2017-01-02
# 3:   a     23 2017-02-04
# 4:   a     15 2017-04-19
# 5:   a     19 2018-02-11
# 6:   b     42 2016-02-14
# 7:   b     11 2017-05-06
# 8:   b      5 2017-05-12
# 9:   b     10 2017-12-12

I am trying to determine, for each row, the cumulative amount for its urn over the preceding 12 months. I know I can use shift() in data.table to look backwards or forwards, but the challenge I can't get my head around is how to know how many records to sum, since that number varies depending on how many records each urn has within the window.

The results I am looking for are:

dt1
#    urn amount       date summed12m
# 1:   a     10 2016-01-01        10
# 2:   a     12 2017-01-02        12
# 3:   a     23 2017-02-04        35
# 4:   a     15 2017-04-19        50
# 5:   a     19 2018-02-11        34
# 6:   b     42 2016-02-14        42
# 7:   b     11 2017-05-06        11
# 8:   b      5 2017-05-12        16
# 9:   b     10 2017-12-12        26   
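To pin down the intended semantics (a closed window from 12 months before each row's date up to that date, within the same urn), here is a naive reference implementation. It is a sketch that assumes lubridate is available for the month arithmetic; naive_sum12m and check12m are hypothetical names introduced here. Its cost is quadratic in the number of rows, so it is only meant for checking a faster solution on small samples, not for the 12M-row table.

```r
library(data.table)
library(lubridate)

dt1 <- data.table(urn = c(rep("a", 5), rep("b", 4)),
                  amount = c(10, 12, 23, 15, 19, 42, 11, 5, 10),
                  date = as.Date(c("2016-01-01", "2017-01-02", "2017-02-04",
                                   "2017-04-19", "2018-02-11", "2016-02-14",
                                   "2017-05-06", "2017-05-12", "2017-12-12")))

# Naive reference: for each row, sum the amounts of all rows with the same urn
# whose date lies in the closed window [date - 12 months, date].
naive_sum12m <- function(urn, amount, date) {
  start <- date %m-% months(12)  # window start, rolled back on invalid days
  vapply(seq_along(date), function(i) {
    sum(amount[urn == urn[i] & date >= start[i] & date <= date[i]])
  }, numeric(1))
}

dt1[, check12m := naive_sum12m(urn, amount, date)]
dt1
```

Row 5 illustrates the window: its start is 2017-02-11, so the 23 dated 2017-02-04 drops out and only 15 + 19 = 34 remains.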

I'd prefer a data.table solution because of the volume of my data (about 12M records), but I'm open to other options if they are likely to be efficient at that scale.

asked Feb 14 '18 at 04:02 by Dan



1 Answer

As an alternative to foverlaps(), this can also be solved by aggregating in a non-equi join:

library(lubridate)
dt1[, summed12m := dt1[.(urn, date, date %m-% months(12)), 
                       on = .(urn = V1, date <= V2, date >= V3), 
                       sum(amount), by = .EACHI]$V1][]
#    urn amount       date summed12m
# 1:   a     10 2016-01-01        10
# 2:   a     12 2017-01-02        12
# 3:   a     23 2017-02-04        35
# 4:   a     15 2017-04-19        50
# 5:   a     19 2018-02-11        34
# 6:   b     42 2016-02-14        42
# 7:   b     11 2017-05-06        11
# 8:   b      5 2017-05-12        16
# 9:   b     10 2017-12-12        26

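lubridate's %m-% operator is used for the date arithmetic to avoid mishaps in case one of the dates is February 29: plain subtraction of months(12) returns NA when the resulting date does not exist, whereas %m-% rolls back to the last valid day of the month.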
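A small check of the leap-day behaviour, assuming lubridate is installed (leap_day is just an illustrative variable):

```r
library(lubridate)

leap_day <- as.Date("2016-02-29")

# Plain period arithmetic: 2015-02-29 does not exist, so the result is NA
leap_day - months(12)
# [1] NA

# %m-% rolls back to the last existing day of the target month instead
leap_day %m-% months(12)
# [1] "2015-02-28"
```

An NA window start would make the non-equi join find no matches for that row, so the rollback matters for correctness, not just convenience.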

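The essential part is the non-equi join, which, for each row of dt1, sums amount over all rows of the same urn whose date lies between 12 months before that row's date and the date itself (both bounds inclusive):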

dt1[.(urn, date, date %m-% months(12)), 
    on = .(urn = V1, date <= V2, date >= V3), 
    sum(amount), by = .EACHI]
#    urn       date       date V1
# 1:   a 2016-01-01 2015-01-01 10
# 2:   a 2017-01-02 2016-01-02 12
# 3:   a 2017-02-04 2016-02-04 35
# 4:   a 2017-04-19 2016-04-19 50
# 5:   a 2018-02-11 2017-02-11 34
# 6:   b 2016-02-14 2015-02-14 42
# 7:   b 2017-05-06 2016-05-06 11
# 8:   b 2017-05-12 2016-05-12 16
# 9:   b 2017-12-12 2016-12-12 26

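The last column, V1, is then picked to create the new summed12m column in dt1. This works because by = .EACHI returns one row per row of i, in i's order, and every row's window contains at least the row itself, so the sums line up with the rows of dt1.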
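Picking the result column by position relies on by = .EACHI returning one result row per row of i, in i's order. The sketch below checks that on a deliberately shuffled copy of the data (shuffled and agg are illustrative names), which also shows that the join does not require the table to be sorted beforehand.

```r
library(data.table)
library(lubridate)

dt1 <- data.table(urn = c(rep("a", 5), rep("b", 4)),
                  amount = c(10, 12, 23, 15, 19, 42, 11, 5, 10),
                  date = as.Date(c("2016-01-01", "2017-01-02", "2017-02-04",
                                   "2017-04-19", "2018-02-11", "2016-02-14",
                                   "2017-05-06", "2017-05-12", "2017-12-12")))

# Scramble the row order on purpose
set.seed(1)
shuffled <- dt1[sample(nrow(dt1))]

# Same non-equi join as above, built from the shuffled rows
agg <- shuffled[.(urn, date, date %m-% months(12)),
                on = .(urn = V1, date <= V2, date >= V3),
                sum(amount), by = .EACHI]

# One aggregated row per row of i, in i's order
stopifnot(nrow(agg) == nrow(shuffled))
shuffled[, summed12m := agg$V1]

# Sorting back by urn and date reproduces the expected result
shuffled[order(urn, date)]
```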

Additional explanation

The OP has asked where V1, V2, and V3 come from.

The expression .(urn, date, date %m-% months(12)) creates a new data.table on the fly (.() is a data.table alias for list()). As no column names have been specified, data.table creates the default column names V1, V2, etc.
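The default naming can be seen by converting an unnamed list with as.data.table(), which roughly mirrors what happens to the list built by .() inside i (i_table is an illustrative name):

```r
library(data.table)

# An unnamed list, like the one .(urn, date, ...) produces inside i
i_table <- as.data.table(list(c("a", "b"),
                              as.Date(c("2017-01-02", "2017-05-06"))))
names(i_table)
# [1] "V1" "V2"
```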

Less sloppily, the expression can be rewritten with explicitly named columns:

dt1[.(urn = urn, end = date, start = date %m-% months(12)), 
    on = .(urn, date <= end, date >= start), 
    sum(amount), by = .EACHI]
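Putting the named version to work, a sketch that also names the aggregated column so that no default V1 is left. The names windows, agg and total are choices made here, not part of the answer above.

```r
library(data.table)
library(lubridate)

dt1 <- data.table(urn = c(rep("a", 5), rep("b", 4)),
                  amount = c(10, 12, 23, 15, 19, 42, 11, 5, 10),
                  date = as.Date(c("2016-01-01", "2017-01-02", "2017-02-04",
                                   "2017-04-19", "2018-02-11", "2016-02-14",
                                   "2017-05-06", "2017-05-12", "2017-12-12")))

# One window per row:
#   urn   : group to match
#   end   : upper bound, the row's own date
#   start : lower bound, 12 months earlier (with month-end rollback)
windows <- dt1[, .(urn = urn, end = date, start = date %m-% months(12))]

# Non-equi join: per window, sum the amounts of the same urn dated inside it
agg <- dt1[windows,
           on = .(urn, date <= end, date >= start),
           .(total = sum(amount)), by = .EACHI]

dt1[, summed12m := agg$total]
dt1
```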
answered Oct 18 '22 at 11:10 by Uwe
