
R grouping by condition in data.table

In R, I have a large data.table. For every row, I want to count the rows whose x1 value is within some tolerance (tol) of that row's x1. I can get this to work using adply, but it's too slow. It seems like the sort of thing data.table would be good for - in fact, I'm already using data.table for part of the computation.

Is there a way to do this entirely with data.table? Here is an example:

library(data.table)
library(plyr)
my.df = data.table(x1 = 1:1000,
                   x2 = 4:1003)
tol = 3
adply(my.df, 1, function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])

Results:

        x1   x2 V1
   1:    1    4  3
   2:    2    5  4
   3:    3    6  5
   4:    4    7  5
   5:    5    8  5
  ---             
 996:  996  999  5
 997:  997 1000  5
 998:  998 1001  5
 999:  999 1002  4
1000: 1000 1003  3

Update:

Here's a sample dataset that is a little closer to my real data:

set.seed(10)
x = seq(1,100000000,100000)
x = x + sample(1:50000, length(x), replace=T)
x2 = x + sample(1:50000, length(x), replace=T)
my.df = data.table(x1 = x,
                   x2 = x2)
setkey(my.df,x1)
tol = 100000

og = function(my.df) {
  adply(my.df, 1, function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])
}

library(microbenchmark)
# ed and ar are the functions from @eddi's and @Arun's answers below
microbenchmark(r_ed <- ed(copy(my.df)),
               r_ar <- ar(copy(my.df)),
               r_og <- og(copy(my.df)),
               times = 1)

Unit: milliseconds
                    expr         min          lq      median          uq         max neval
 r_ed <- ed(copy(my.df))    8.553137    8.553137    8.553137    8.553137    8.553137     1
 r_ar <- ar(copy(my.df))   10.229438   10.229438   10.229438   10.229438   10.229438     1
 r_og <- og(copy(my.df)) 1424.472844 1424.472844 1424.472844 1424.472844 1424.472844     1

Clearly, the solutions from both @eddi and @Arun are much faster than mine. Now I just have to try to understand rolling joins.

asked Aug 08 '13 by benjamin



3 Answers

See @eddi's answer for a faster solution (to this particular problem). It also works when x1 is not an integer.

The data structure you're looking for is an interval tree, and there's a Bioconductor package called IRanges that handles exactly this task. It's hard to beat.
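As a side note, IRanges is distributed through Bioconductor rather than CRAN; a sketch of one common install route (assuming a reasonably current R setup with the BiocManager package, which is my addition here, not part of the original answer):

# IRanges comes from Bioconductor, not CRAN
install.packages("BiocManager")
BiocManager::install("IRanges")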

require(IRanges)
require(data.table)
my.df[, res := countOverlaps(IRanges(my.df$x1, width=1), 
           IRanges(my.df$x1-tol+1, my.df$x1+tol-1))]

Some explanation:

If you break down the code, you can write it in three lines:

ir1 <- IRanges(my.df$x1, width=1)
ir2 <- IRanges(my.df$x1-tol+1, my.df$x1+tol-1)
cnt <- countOverlaps(ir1, ir2)

What we essentially do is create two sets of "ranges" (just type ir1 and ir2 to see what they look like). Then, for each entry in ir1, we ask how many entries in ir2 it overlaps (this is the "interval tree" part), and this is very efficient. The type argument to countOverlaps defaults to "any"; you can explore the other types if you want. It's extremely useful. Also of relevance is the findOverlaps function.
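As a toy sketch of that breakdown (the tiny vector here is illustrative, not from the question):

library(IRanges)
x   <- c(1, 3, 10)                        # three example points
tol <- 3
ir1 <- IRanges(x, width = 1)              # the query points
ir2 <- IRanges(x - tol + 1, x + tol - 1)  # the tolerance windows
countOverlaps(ir1, ir2)                   # counts per point: 2 2 1
findOverlaps(ir1, ir2)                    # the matching (query, subject) pairs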

Note: there can be faster solutions for this particular case, where the width of ir1 is 1 (in fact there is one; see @eddi's answer). But for problems where the widths are variable and/or greater than 1, this should be the fastest.


Benchmarking:

ag <- function(my.df) my.df[, res := sum(abs(my.df$x1-x1) < tol), by=x1]
ro <- function(my.df) {
          my.df[, res := { y = my.df$x1
                           sum(y > (x1 - tol) & y < (x1 + tol)) },
                by = x1]
      }
ar <- function(my.df) {
           my.df[, res := countOverlaps(IRanges(my.df$x1, width=1), 
            IRanges(my.df$x1-tol+1, my.df$x1+tol-1))]
      }


require(microbenchmark)
microbenchmark(r1 <- ag(copy(my.df)), r2 <- ro(copy(my.df)), 
               r3 <- ar(copy(my.df)), times=100)

Unit: milliseconds
                  expr      min       lq   median       uq       max neval
 r1 <- ag(copy(my.df)) 33.15940 39.63531 41.61555 44.56616 208.99067   100
 r2 <- ro(copy(my.df)) 69.35311 76.66642 80.23917 84.67419 344.82031   100
 r3 <- ar(copy(my.df)) 11.22027 12.14113 13.21196 14.72830  48.61417   100 <~~~

identical(r1, r2) # TRUE
identical(r1, r3) # TRUE
answered by Arun


Here's a faster data.table solution. The idea is to use data.table's rolling-merge functionality, but before we do that we need to modify the data slightly and make the column x1 numeric instead of integer. This is because the OP uses strict inequality, and to express that with rolling joins we have to shrink the tolerance by a tiny amount, turning it into a floating-point number.

my.df[, x1 := as.numeric(x1)]

# set the key to x1 for the merges and to sort
# (note, if data already sorted can make this step instantaneous using setattr)
setkey(my.df, x1)

# and now we're going to do two rolling merges, one with the upper bound
# and one with lower, then get the index of the match and subtract the ends
# (+1, to get the count)
my.df[, res := my.df[J(x1 + tol - 1e-6), list(ind = .I), roll = Inf]$ind -
               my.df[J(x1 - tol + 1e-6), list(ind = .I), roll = -Inf]$ind + 1]
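
To see concretely what each rolling join returns, here's a toy illustration (this small table is mine, not from the answer):

dt <- data.table(x1 = c(1, 5, 9), key = "x1")
dt[J(6), list(ind = .I), roll =  Inf]   # ind = 2: last row with x1 <= 6
dt[J(6), list(ind = .I), roll = -Inf]   # ind = 3: first row with x1 >= 6

Subtracting those two indices (and adding 1) counts the rows whose x1 lies between the two bounds.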


# and here's the bench vs @Arun's solution
ed = function(my.df) {
  my.df[, x1 := as.numeric(x1)]
  setkey(my.df, x1)
  my.df[, res := my.df[J(x1 + tol - 1e-6), list(ind = .I), roll = Inf]$ind -
                 my.df[J(x1 - tol + 1e-6), list(ind = .I), roll = -Inf]$ind + 1]
}

microbenchmark(ed(copy(my.df)), ar(copy(my.df)))
#Unit: milliseconds
#            expr       min       lq   median       uq      max neval
# ed(copy(my.df))  7.297928 10.09947 10.87561 11.80083 23.05907   100
# ar(copy(my.df)) 10.825521 15.38151 16.36115 18.15350 21.98761   100

Note: as both Arun and Matthew pointed out, if x1 is integer, one doesn't have to convert to numeric and subtract a small amount from tol, and can simply use tol - 1L instead of tol - 1e-6 above.
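
For reference, here's a sketch of that integer variant (my restatement of the note above; it assumes x1 and tol are both stored as integers):

ed_int = function(my.df) {
  setkey(my.df, x1)   # x1 stays integer; no conversion needed
  my.df[, res := my.df[J(x1 + tol - 1L), list(ind = .I), roll =  Inf]$ind -
                 my.df[J(x1 - tol + 1L), list(ind = .I), roll = -Inf]$ind + 1L]
}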

answered by eddi


Here is a pure data.table solution:

my.df[, res := sum(my.df$x1 > (x1 - tol) & my.df$x1 < (x1 + tol)), by = x1]

# re-run the OP's adply approach to check that both give the same counts
my.df <- adply(my.df, 1, 
           function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])

identical(my.df[, res], my.df[, V1])
#[1] TRUE

However, this will still be relatively slow if you have many unique x1 values. After all, you need to do a huge number of comparisons, and I can't think of a way to avoid that right now.
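As an aside (a sketch of mine, not part of this answer): once x1 is sorted, the per-row count can also be computed with two binary searches in base R via findInterval, which is essentially the trick the rolling-join answer above exploits:

x   <- sort(my.df$x1)
eps <- 0.5   # any value in (0, 1) keeps the upper bound strict for integer data
res <- findInterval(x + tol - eps, x) -  # count of values <  x + tol
       findInterval(x - tol, x)          # minus count of values <= x - tol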

answered by Roland