
Translate this loop into purrr?

Tags: r, purrr, tidyverse

I'm trying to make a teaching example on sampling for students to run, but it becomes too slow when the number of iterations gets into the thousands (the real data frame df has several million rows).

Can I speed this up with purrr?

library(tidyverse)
set.seed(1432)
df <- data.frame(v1 = sample(1:10, 100, replace=TRUE),
                 v2 = c(rep("A", 50), rep("B", 50))
)

output <- NULL

for (i in 1:10) {
  set.seed(i)
  d <- df %>%
    filter(v2=="A") %>%
    sample_n(20, replace=FALSE)

  # Use a name that does not shadow the base function mean()
  m <- mean(d$v1)
  output <- c(output, m)
}

output
asked Oct 29 '19 17:10 by Eric Green



2 Answers

You can use purrr as follows.

map_dbl(1:10, function(x){
  set.seed(x)
  d <- df %>%
    filter(v2=="A") %>%
    sample_n(20, replace=FALSE)

  return(mean(d$v1))
})
# [1] 5.15 5.90 5.70 5.55 5.60 4.95 5.40 5.40 5.65 5.40
answered Oct 10 '22 23:10 by www



purrr is not necessarily faster than a loop, but it is more readable than the basic control structures in R. If the goal is only to replace the loop, base R's sapply() does the same job:

sapply(1:10, function(x){
  set.seed(x)
  d <- df %>%
    filter(v2=="A") %>%
    sample_n(20, replace=FALSE)
  mean(d$v1)
})
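
As a hedged aside, sapply() guesses its return type, whereas map_dbl() always returns a double vector. A minimal sketch of the closest base R analogue, assuming the same df as in the question, is vapply() with FUN.VALUE = numeric(1). The variable name means is only illustrative, and this variant is about type safety, not speed:

```r
library(dplyr)

set.seed(1432)
df <- data.frame(v1 = sample(1:10, 100, replace = TRUE),
                 v2 = c(rep("A", 50), rep("B", 50)))

# vapply() errors if any iteration returns something other than one numeric value
means <- vapply(1:10, function(x) {
  set.seed(x)
  d <- df %>%
    filter(v2 == "A") %>%
    sample_n(20, replace = FALSE)
  mean(d$v1)
}, FUN.VALUE = numeric(1))

means
```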

UPDATE: Using dplyr and purrr does not guarantee that your code will be fast. IMO, these packages were developed primarily to improve code readability rather than to speed up expensive calculations. You can achieve a significant speedup by working directly with base R data structures. In the benchmark below, d is the original loop, a and b are the functional programming solutions (sapply and map_dbl), and f is the optimized base R loop:

a <- function(y){sapply(1:y, function(x){
  set.seed(x)
  d <- df %>%
    filter(v2=="A") %>%
    sample_n(20, replace=FALSE)
    mean(d$v1)
})}

b <- function(y) {map_dbl(1:y, function(x){
  set.seed(x)
  d <- df %>%
    filter(v2=="A") %>%
    sample_n(20, replace=FALSE)

  return(mean(d$v1))
})}

d <- function(y){
  output <- NULL
  for (i in 1:y) {
    set.seed(i)
    d <- df %>%
      filter(v2=="A") %>%
      sample_n(20, replace=FALSE)
    output <- c(output, mean(d$v1))
  }

  output
}

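f <- function(y){
  # Preallocate the result instead of growing it with c() on every iteration.
  # Note: this is a list; wrap the result in unlist() to get a numeric vector like a, b, and d.
  output <- vector("list", y)
  for (i in 1:y) {
    set.seed(i)
    # Base R subsetting replaces the dplyr filter() and sample_n() calls.
    d <- df[df$v2 == "A", ]
    d <- d[sample(1:nrow(d), 20, replace = FALSE), ]

    output[[i]] <- mean(d$v1)
  }

  output
}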

microbenchmark::microbenchmark(a(100),b(100),d(100), f(100))

Unit: milliseconds
   expr       min        lq      mean    median        uq       max neval
 a(100) 172.06305 187.95053 205.19531 199.84411 210.55501 306.41906   100
 b(100) 171.86030 186.18869 206.50518 196.07746 213.79044 397.87859   100
 d(100) 174.45273 191.01706 208.07125 199.12653 216.54543 365.55107   100
 f(100)  14.62159  15.80092  20.96736  19.14848  24.16181  37.54095   100

Observe that f is roughly 10x faster than d (median of about 19 ms versus 199 ms), while a, b, and d run at almost the same speed.
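
The same idea can probably be pushed a step further. The rows with v2 == "A" never change between iterations, so the filter can be done once, outside the loop, leaving only a plain numeric vector to sample from. The sketch below is not part of the benchmark above and has not been timed here; the names v1_a and g are hypothetical. It draws the same indices as f for a given seed, since sample(1:n, k) and sample.int(n, k) produce the same draw when n > 1:

```r
set.seed(1432)
df <- data.frame(v1 = sample(1:10, 100, replace = TRUE),
                 v2 = c(rep("A", 50), rep("B", 50)))

# Filter once: keep only the v1 values of group "A" as a plain vector
v1_a <- df$v1[df$v2 == "A"]

g <- function(y) {
  vapply(seq_len(y), function(i) {
    set.seed(i)
    # Sample 20 positions without replacement and average the values there
    mean(v1_a[sample.int(length(v1_a), 20)])
  }, FUN.VALUE = numeric(1))
}

g(10)
```

Unlike f, g returns a numeric vector directly, so no unlist() is needed. On a data frame with millions of rows, most of the remaining cost would be the single filter step.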

answered Oct 10 '22 23:10 by slava-kohut
