I need to sample an outcome variable given a matrix with row-wise outcome probabilities.
set.seed(1010) #reproducibility
#create a matrix of probabilities
#three possible outcomes, 10,000 cases
probabilities <- matrix(runif(10000*3),nrow=10000,ncol=3)
probabilities <- probabilities / Matrix::rowSums(probabilities)
The fastest way I could come up with is a combination of apply() and sample().
#row-wise sampling using these probabilities
classification <- apply(probabilities, 1, function(x) sample(1:3, 1, prob = x))
However, in what I'm doing, this is the computational bottleneck. Do you have an idea how to speed this code up / how to sample more efficiently?
Thanks!
RLave's comment that Rcpp could be the way to go is spot on (you also need RcppArmadillo for sample()); I used the following C++ code to create such a function:
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadilloExtensions/sample.h>

using namespace Rcpp;

// [[Rcpp::export]]
IntegerVector sample_matrix(NumericMatrix x, IntegerVector choice_set) {
    int n = x.nrow();
    IntegerVector result(n);
    for ( int i = 0; i < n; ++i ) {
        result[i] = RcppArmadillo::sample(choice_set, 1, false, x(i, _))[0];
    }
    return result;
}
I then made that function available in my R session via
Rcpp::sourceCpp("sample_matrix.cpp")
Now we can test it in R against your initial approach, as well as the other suggestions to use purrr::map() and lapply():
set.seed(1010) #reproducibility
#create a matrix of probabilities
#three possible outcomes, 10,000 cases
probabilities <- matrix(runif(10000*3),nrow=10000,ncol=3)
probabilities <- probabilities / Matrix::rowSums(probabilities)
probabilities_list <- split(probabilities, seq(nrow(probabilities)))
library(purrr)
library(microbenchmark)
microbenchmark(
apply = apply(probabilities, 1, function(x) sample(1:3, 1, prob = x)),
map = map(probabilities_list, function(x) sample(1:3, 1, prob = x)),
lapply = lapply(probabilities_list, function(x) sample(1:3, 1, prob = x)),
rcpp = sample_matrix(probabilities, 1:3),
times = 100
)
Unit: milliseconds
   expr       min        lq      mean    median        uq       max neval cld
  apply 307.44702 321.30051 339.85403 342.36421 350.86090 434.56007   100   c
    map 254.69721 265.10187 282.85592 286.21680 295.48886 363.95898   100   b
 lapply 249.68224 259.70178 280.63066 279.87273 287.10062 691.21359   100   b
   rcpp  12.16787  12.55429  13.47837  13.81601  14.25198  16.84859   100   a
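The time savings are considerable: the Rcpp version has a median of about 14 ms, against roughly 280 to 340 ms for the R-level approaches, so it is around 20 to 25 times faster.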
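If compiling C++ is not an option, a vectorised base-R alternative may also be worth trying. The sketch below uses inverse-CDF sampling: draw one uniform number per row and count how many of that row's cumulative probabilities it exceeds. The function name sample_rows is made up for illustration, it assumes at least two outcome columns and rows that sum to 1, and I have not benchmarked it against the timings above. It also consumes random numbers differently from sample(), so the draws follow the same distribution but will not match the other approaches value for value under the same seed.

```r
# Hypothetical helper (not from any package): row-wise categorical sampling
# via the inverse-CDF method. Assumes ncol(p) >= 2 and each row sums to 1.
sample_rows <- function(p) {
  k <- ncol(p)
  # Build row-wise cumulative probabilities one column at a time
  cum <- p
  for (j in seq_len(k)[-1]) {
    cum[, j] <- cum[, j - 1] + p[, j]
  }
  # One uniform draw per row
  u <- runif(nrow(p))
  # The outcome is 1 plus the number of cumulative bounds the draw exceeds;
  # the last column is dropped because it is 1 up to rounding error
  as.integer(rowSums(u > cum[, -k, drop = FALSE]) + 1L)
}

# Small demonstration with the same kind of input as in the question
set.seed(1010)
p <- matrix(runif(10 * 3), nrow = 10, ncol = 3)
p <- p / rowSums(p)
classification <- sample_rows(p)
```

With the matrix from the question this would be called as sample_rows(probabilities), returning an integer vector with one outcome index per row.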