 

Efficient implementation in computing pairwise differences

Tags:

r

Suppose I have a data frame as follows:

> foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
> foo
  x id
1 1  1
2 2  1
3 3  2
4 4  2
5 5  2
6 6  3
7 7  3
8 8  3
9 9  3

I want a very efficient implementation of a function h(a, b) that computes the sum of all products (a - xi)*(b - xj) over pairs xi, xj belonging to the same id class. My current implementation is

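h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff%*%t(b.diff)   # n x n matrix of all (a - xi)*(b - xj)
  # 1 where rows i and j share an id (their dist() is 0), 0 otherwise;
  # as.matrix() puts 0 on the diagonal, hence the added identity matrix
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = TRUE, upper = TRUE), 0, 1)) + diag(nrow(foo))
  return(sum(prod*id.indicator))
}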

For example, with (a, b) = (0, 1), here is the output of each step in the function:

> a.diff
[1] -1 -2 -3 -4 -5 -6 -7 -8 -9
> b.diff
[1]  0 -1 -2 -3 -4 -5 -6 -7 -8
> prod
      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
 [1,]    0    1    2    3    4    5    6    7    8
 [2,]    0    2    4    6    8   10   12   14   16
 [3,]    0    3    6    9   12   15   18   21   24
 [4,]    0    4    8   12   16   20   24   28   32
 [5,]    0    5   10   15   20   25   30   35   40
 [6,]    0    6   12   18   24   30   36   42   48
 [7,]    0    7   14   21   28   35   42   49   56
 [8,]    0    8   16   24   32   40   48   56   64
 [9,]    0    9   18   27   36   45   54   63   72
> id.indicator
  1 2 3 4 5 6 7 8 9
1 1 1 0 0 0 0 0 0 0
2 1 1 0 0 0 0 0 0 0
3 0 0 1 1 1 0 0 0 0
4 0 0 1 1 1 0 0 0 0
5 0 0 1 1 1 0 0 0 0
6 0 0 0 0 0 1 1 1 1
7 0 0 0 0 0 1 1 1 1
8 0 0 0 0 0 1 1 1 1
9 0 0 0 0 0 1 1 1 1
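As a side note, the same 0/1 indicator can be built without dist() and ifelse(): outer(foo$id, foo$id, "==") compares every pair of ids directly. This is only a simpler construction, not a cheaper one, since it still allocates full n x n matrices. A minimal sketch (same.id and total are names introduced here for illustration):

```r
foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))

# (a, b) = (0, 1), as in the step-by-step output above
a.diff <- 0 - foo$x
b.diff <- 1 - foo$x
prod <- outer(a.diff, b.diff)        # same values as a.diff %*% t(b.diff)

# TRUE where rows i and j share an id; in arithmetic TRUE/FALSE act as 1/0
same.id <- outer(foo$id, foo$id, "==")

total <- sum(prod * same.id)
total
```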

In reality, there can be up to 1000 id clusters, each with at least 40 observations. That makes this method too inefficient: id.indicator is mostly zeros, and prod wastes computation on the off-block-diagonal entries, which are never used.
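Written out, the target quantity is a sum over clusters g of double sums restricted to that cluster, which is why every entry of prod outside the diagonal blocks is wasted. Because each double sum runs over all pairs within one cluster, it also factorizes into a product of two single sums (an observation added here, not taken from the post). Below, n_g is the size of cluster g and S_g the sum of its x values:

```latex
h(a, b) = \sum_{g} \sum_{i \in g} \sum_{j \in g} (a - x_i)(b - x_j)
        = \sum_{g} \Big( \sum_{i \in g} (a - x_i) \Big) \Big( \sum_{j \in g} (b - x_j) \Big)
        = \sum_{g} (n_g a - S_g)(n_g b - S_g)
```

For the toy data, clusters 1, 2, 3 have n_g = 2, 3, 4 and S_g = 3, 12, 30, so h(0, 1) = (-3)(-1) + (-12)(-9) + (-30)(-26) = 3 + 108 + 780 = 891. For scale: with 1000 clusters of exactly 40 observations, n = 40,000, so prod has 1.6 × 10^9 entries (about 12.8 GB as doubles), while the diagonal blocks hold only 1000 × 40² = 1.6 × 10^6 entries.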

Asked by Tom Chen, Feb 27 '17 20:02



1 Answer

I played around a bit. First, your implementation:

foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))

h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff%*%t(b.diff)
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T),0,1)) + 
     diag(nrow(foo))
  return(sum(prod*id.indicator))
}

h(a = 1, b = 0, foo = foo)
#[1] 891

Next, I tried a variant that builds the index matrix as a proper sparse block-diagonal matrix (via bdiag() from the Matrix package). I also use tcrossprod(), which I often find to be a bit faster than a %*% t(b).

library("Matrix")

h2 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  prod <- tcrossprod(a.diff, b.diff) # the same as a.diff%*%t(b.diff)
  id.indicator <- do.call(bdiag, lapply(table(foo$id), function(n) matrix(1,n,n)))
  return(sum(prod*id.indicator))
}

h2(a = 1, b = 0, foo = foo)
#[1] 891

Note that this function relies on foo being sorted by id, so that each cluster occupies a contiguous block of rows.
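One way to lift that requirement is to reorder the rows by id inside the function. Since the result only depends on which pairs share an id, permuting the rows (and hence both a.diff and b.diff) does not change it. A sketch, where h2.sorted and foo.shuffled are hypothetical names and foo.shuffled is a row permutation of the question's foo:

```r
library("Matrix")

# Hypothetical wrapper: sort rows by id first, because bdiag() lays the
# blocks out in the order given by table(foo$id), i.e. sorted by id
h2.sorted <- function(a, b, foo) {
  foo <- foo[order(foo$id), , drop = FALSE]
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  prod <- tcrossprod(a.diff, b.diff)
  id.indicator <- do.call(bdiag, lapply(table(foo$id), function(n) matrix(1, n, n)))
  sum(prod * id.indicator)
}

# Same (x, id) pairs as foo, but in shuffled row order
foo.shuffled <- data.frame(x  = c(6, 1, 3, 9, 2, 4, 7, 5, 8),
                           id = c(3, 1, 2, 3, 1, 2, 3, 2, 3))
h2.sorted(a = 1, b = 0, foo = foo.shuffled)
```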

Lastly, I tried to avoid creating the full n-by-n matrix.

h3 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  ids <- unique(foo$id)
  res <- 0
  for (i in seq_along(ids)) {
    indx <- which(foo$id == ids[i])
    res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
  }
  return(res)
}

h3(a = 1, b = 0, foo = foo)
#[1] 891
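A possible refinement of h3 (not part of the original answer): which(foo$id == ids[i]) scans the whole id vector once per cluster, so the loop performs on the order of n × (number of clusters) comparisons. split() partitions the differences by id in a single pass and does not require the rows to be sorted. A sketch; h3.split is a hypothetical name:

```r
foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))

h3.split <- function(a, b, foo) {
  # Partition both difference vectors by id; both split() calls use the
  # same grouping, so the pieces line up cluster by cluster
  a.parts <- split(a - foo$x, foo$id)
  b.parts <- split(b - foo$x, foo$id)
  # Sum of each cluster's outer product, then the total over clusters
  per.cluster <- mapply(function(ad, bd) sum(tcrossprod(ad, bd)), a.parts, b.parts)
  sum(per.cluster)
}

h3.split(a = 1, b = 0, foo = foo)
```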

Benchmarking on your example:

library("microbenchmark")
microbenchmark(h(a = 1, b = 0, foo = foo), 
               h2(a = 1, b = 0, foo = foo),
               h3(a = 1, b = 0, foo = foo))
# Unit: microseconds
#                        expr      min        lq      mean    median        uq       max neval
#  h(a = 1, b = 0, foo = foo)  248.569  261.9530  493.2326  279.3530  298.2825 21267.890   100
# h2(a = 1, b = 0, foo = foo) 4793.546 4893.3550 5244.7925 5051.2915 5386.2855  8375.607   100
# h3(a = 1, b = 0, foo = foo)  213.386  227.1535  243.1576  234.6105  248.3775   334.612   100

In this example, h3 is the fastest and h2 is really slow. I would expect both h2 and h3 to compare better against h on larger examples, and h3 will probably still win there. There is plenty of room for more optimization, but h3 should already be faster and more memory efficient than h, since it never builds an n-by-n matrix. So I think you should go for a variant of h3 that does not create unnecessarily large matrices.
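The optimization can be taken further than h3 (this goes beyond the original answer): within one cluster, the sum of (a - xi)*(b - xj) over all pairs equals (sum of a - xi) times (sum of b - xj), so h only needs per-cluster sums and no outer products at all. A minimal sketch using base R's rowsum(); h4 and the simulated data set big are hypothetical names, and the code has not been benchmarked against h3 here:

```r
foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))

# Sketch based on the within-cluster factorization
#   sum_{i,j in g} (a - x_i)(b - x_j) = (sum_{i in g} (a - x_i)) * (sum_{j in g} (b - x_j))
h4 <- function(a, b, foo) {
  # rowsum() returns a one-column matrix with one row per id (sorted by id)
  sa <- rowsum(a - foo$x, foo$id)
  sb <- rowsum(b - foo$x, foo$id)
  # Both results share the same row order, so multiply cluster by cluster
  sum(sa * sb)
}

h4(a = 1, b = 0, foo = foo)

# Simulated data of the size described in the question
# (1000 clusters of 40 observations); no n x n matrix is built
set.seed(1)
big <- data.frame(x = rnorm(40000), id = rep(1:1000, each = 40))
h4(a = 1, b = 0, foo = big)
```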

Answered by Anders Ellern Bilgrau, Oct 24 '22 03:10
