
Faster coding than using for loop

Tags: r, dplyr

Suppose I have the following data frame:

set.seed(36)    

n <- 300

dat <- data.frame(x = round(runif(n,0,200)), y =  round(runif(n, 0, 500)))
d <- dat[order(dat$y),]

For each value of d$y <= 300, I need to compute a variable res whose numerator is the number of x values satisfying d$x <= d$y[i] and whose denominator is the number of y values satisfying d$y >= d$y[i]. I have written this as a for loop:

res <- NULL

for( i in seq_len(sum(d$y<=300)) ){

    numerator   <- sum(d$x <= d$y[i])
    denominator <- sum(d$y >= d$y[i])

    res[i] <- numerator / denominator
 } 

My concern is that when the number of observations (the number of rows of the data frame) is large, the for loop will be slow. Additionally, if I simulate the data 1000 times and run the for loop each time, the program will be inefficient.

What would be a more efficient way to write this code?

user 31466 asked Feb 24 '17 03:02


2 Answers

This depends on d already being sorted by y, as it is in the question:

# example data
set.seed(36)    
n <- 1e5
dat <- data.frame(x = round(runif(n,0,200)), y =  round(runif(n, 0, 500)))
d <- dat[order(dat$y),]

My suggestion (thanks to @alexis_laz for the denominator):

system.time(res3 <- {
  xs <- sort(d$x)                                 # sorted x
  yt <- d$y[d$y <= 300]                           # truncated y

  num = findInterval(yt, xs)
  den = length(d$y) - match(yt, d$y) + 1L

  num/den
})
#    user  system elapsed 
#       0       0       0 

OP's approach:

system.time(res <- {
  res <- NULL
  for( i in seq_len(sum(d$y<=300)) ){
    numerator   <- sum(d$x <= d$y[i])
    denominator <- sum(d$y >= d$y[i])
    res[i] <- numerator / denominator
  }
  res
})
#    user  system elapsed 
#   50.77    1.13   52.10 

# verify it matched
all.equal(res,res3) # TRUE

@d.b's approach:

system.time(res2 <- {
  numerator = rowSums(outer(d$y, d$x, ">="))
  denominator = rowSums(outer(d$y, d$y, "<="))
  res2 = numerator/denominator

  res2 = res2[d$y <= 300]
  res2
})
# Error: cannot allocate vector of size 74.5 Gb

# ^ This error is common when using outer() on large-ish problems
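
The size in that error message is consistent with outer() expanding its arguments to a full n-by-n vector of doubles before comparing them. A rough check of the arithmetic, assuming 8 bytes per double and that R reports the size in units of 2^30 bytes (both are assumptions about the message, not something stated above):

```r
n <- 1e5                      # rows in the example data
bytes_per_double <- 8         # assumed storage size of one numeric value
size_gib <- n^2 * bytes_per_double / 2^30

round(size_gib, 1)
# [1] 74.5
```

Memory for this approach grows with the square of n, so doubling the number of rows roughly quadruples the allocation.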

Vectorization. Generally, tasks are faster in R if they can be vectorized. The key functions related to ordered vectors have confusing names (findInterval, sort, order and cut), but fortunately they all work on vectors.
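
As a small illustration of why findInterval helps here, the sketch below uses made-up vectors (x, y, and the counts_* names are for illustration only) to show that findInterval on a sorted vector gives the same counts as looping with sum(x <= v):

```r
x <- c(5, 1, 3, 3, 8)
y <- c(0, 3, 4, 10)

xs <- sort(x)                                      # findInterval() needs a sorted vector
counts_fast <- findInterval(y, xs)                 # number of xs values <= each y
counts_loop <- sapply(y, function(v) sum(x <= v))  # same count, one y at a time

counts_fast
# [1] 0 3 3 5
identical(counts_fast, as.integer(counts_loop))
# [1] TRUE
```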

Continuous vs discrete. The match above should be a fast way to compute the denominator whether the data is continuous or has mass points / repeating values. If the data is continuous (and so has no repeats), the denominator can just be seq(length(xs), length = length(yt), by=-1). If it is fully discrete and has a lot of repetition (like the example here), there might be some way to make that faster as well, maybe like one of these:

 den2 <- inverse.rle(with(rle(yt), list(
    values = length(xs) - length(yt) +  rev(cumsum(rev(lengths))), 
    lengths = lengths)))

 tab  <- unname(table(yt))
 den3 <- rep(rev(cumsum(rev(tab))) + length(xs) - length(yt), tab)

 # verify
 all.equal(den,den2) # TRUE
 all.equal(den,den3) # TRUE
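
For the continuous case, the seq() shortcut can be checked against the match-based denominator on tie-free data. This is a minimal sketch on made-up data; yc, ytc, den_match and den_seq are illustrative names, not part of the example above:

```r
set.seed(1)
n   <- 1000
yc  <- sort(runif(n, 0, 500))      # continuous y, already sorted
stopifnot(!anyDuplicated(yc))      # the shortcut relies on there being no ties
ytc <- yc[yc <= 300]

den_match <- length(yc) - match(ytc, yc) + 1L
den_seq   <- seq(length(yc), length.out = length(ytc), by = -1)

all.equal(den_match, den_seq)
```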

findInterval will still work for the numerator with continuous data. It is probably not ideal for the repeated-values case considered here, since it redundantly finds the interval for many repeated yt values. Similar ideas for speeding that up likely apply.
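
One possible way to avoid the redundant lookups, offered as an untimed sketch rather than a measured improvement: since yt is sorted, rle() gives one run per distinct value, so findInterval can be called once per distinct value and the result expanded with rep(). The smaller n and the names r and num2 are choices made for this illustration:

```r
set.seed(36)
n   <- 1e4
dat <- data.frame(x = round(runif(n, 0, 200)), y = round(runif(n, 0, 500)))
d   <- dat[order(dat$y), ]

xs <- sort(d$x)
yt <- d$y[d$y <= 300]

num <- findInterval(yt, xs)                          # one lookup per element of yt

r    <- rle(yt)                                      # yt is sorted: one run per distinct value
num2 <- rep(findInterval(r$values, xs), r$lengths)   # one lookup per distinct value

identical(num, num2)
```

Whether this is actually faster depends on how many repeats there are, so it is worth timing on the real data.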

Other options. As @chinsoon suggested, the data.table package might be a good fit if findInterval is too slow, since it has a lot of features focused on sorted data, but it's not obvious to me how to apply it here.

Frank answered Sep 30 '22 20:09


Instead of running a loop, generate all the numerators and denominators at once. This also lets you keep track of which res is associated with which x and y. Later, you can keep only the ones you want.

You can use outer for element-wise comparison between vectors.

numerator = rowSums(outer(d$y, d$x, ">=")) #Compare all y against all x
denominator = rowSums(outer(d$y, d$y, "<=")) #Compare all y against itself
res2 = numerator/denominator #Obtain 'res' for all rows
#I would first 'cbind' res2 to d and only then remove the ones for 'y <=300'

res2 = res2[d$y <= 300] #Keep only those 'res' that you want

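Since this replaces the loop with vectorized rowSums calls, it should be faster for moderately sized data. Note that outer builds a full matrix with one row and one column per observation, so memory use grows quadratically with the number of rows.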
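
If the full matrix does not fit in memory, one option is to apply outer to blocks of y values rather than to all of them at once. This is only a sketch of that idea: chunk_size is a hypothetical tuning value, n is kept small so the loop comparison runs quickly, and nothing here has been benchmarked.

```r
set.seed(36)
n   <- 2000
dat <- data.frame(x = round(runif(n, 0, 200)), y = round(runif(n, 0, 500)))
d   <- dat[order(dat$y), ]

yt <- d$y[d$y <= 300]                        # only the y values that are needed
chunk_size <- 500                            # hypothetical; tune to available memory
chunks <- split(yt, ceiling(seq_along(yt) / chunk_size))

res_chunked <- unlist(lapply(chunks, function(yc) {
  num <- rowSums(outer(yc, d$x, ">="))       # count of x <= each y in this chunk
  den <- rowSums(outer(yc, d$y, "<="))       # count of y >= each y in this chunk
  num / den
}), use.names = FALSE)

# Compare with the per-element computation from the question
res_loop <- sapply(yt, function(v) sum(d$x <= v) / sum(d$y >= v))
all.equal(res_chunked, res_loop)
```

Each block allocates a chunk_size-by-n matrix instead of an n-by-n one, trading some speed for a bounded memory footprint.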

d.b answered Sep 30 '22 21:09