 

Efficiently randomly drawing from a multivariate normal distribution

Tags: r, rcpp

Has anyone run into the need to draw randomly from a very high-dimensional multivariate normal distribution (say, dimension = 10,000)? The rmvnorm function of the mvtnorm package is too slow to be practical at that size.

I know that this article provides an Rcpp implementation of the dmvnorm function from the mvtnorm package, so I was wondering whether something equivalent exists for rmvnorm.

asked Mar 29 '14 at 23:03 by user1701545


1 Answer

Here's a quick comparison of mvtnorm::rmvnorm and an Rcpp implementation given here by Ahmadou Dicko. The times shown are for 100 draws from a multivariate normal distribution, with dimension ranging from 500 to 5000 for the Rcpp version and from 500 to 2500 for mvtnorm::rmvnorm. From the graph below you can probably extrapolate the time required for a dimension of 10,000. The times include the overhead of generating the random mu vector and the diagonal covariance matrix, but this overhead is the same for both approaches and is trivial at these dimensions (e.g. 0.2 sec for diag(10000)).
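The Rcpp code below rests on the standard Cholesky construction, sketched here for reference. If R is the upper-triangular factor returned by arma::chol (so that R'R = Sigma) and z is a vector of d independent standard normals, then:

$$
\Sigma = R^\top R, \qquad z \sim \mathcal{N}(0, I_d) \;\Rightarrow\; x = \mu + R^\top z \sim \mathcal{N}\!\left(\mu,\; R^\top I_d R\right) = \mathcal{N}(\mu, \Sigma)
$$

$$
X = \mathbf{1}_n \mu^\top + Z R, \qquad Z \in \mathbb{R}^{n \times d},\; Z_{ij} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1)
$$

Each row of X is one draw, which is what `arma::repmat(mu, 1, n).t() + Y * arma::chol(sigma)` computes. The factorization costs on the order of d^3 operations and the product on the order of n d^2, so for a fixed number of draws the cost grows roughly cubically with the dimension.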

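library(Rcpp)
library(RcppArmadillo)  # Armadillo linear algebra headers for Rcpp
library(inline)         # provides cxxfunction()
library(mvtnorm)        # rmvnorm(), used as the baseline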

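code <- '
using namespace Rcpp;
int n = as<int>(n_);                        // number of draws
arma::vec mu = as<arma::vec>(mu_);          // mean vector (length d)
arma::mat sigma = as<arma::mat>(sigma_);    // d x d covariance matrix
int ncols = sigma.n_cols;
arma::mat Y = arma::randn(n, ncols);        // n x d matrix of iid N(0, 1) draws
// chol() returns upper-triangular R with R.t() * R == sigma, so each row of
// Y * R has covariance sigma; adding mu to every row shifts the mean
return wrap(arma::repmat(mu, 1, n).t() + Y * arma::chol(sigma));
'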

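# Compile the C++ body into an R function: rmvnorm.rcpp(n_, mu_, sigma_)
rmvnorm.rcpp <-
  cxxfunction(signature(n_="integer", mu_="numeric", sigma_="matrix"), code,
              plugin="RcppArmadillo", verbose=TRUE)

# Elapsed time (3rd element of system.time) for 100 draws, identity covariance
rcpp.time <- sapply(seq(500, 5000, 500), function(x) {
  system.time(rmvnorm.rcpp(100, rnorm(x), diag(x)))[3]
})

# mvtnorm is timed only up to dimension 2500
mvtnorm.time <- sapply(seq(500, 2500, 500), function(x) {
  system.time(rmvnorm(100, rnorm(x), diag(x)))[3]
})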


# Rcpp timings in black, mvtnorm timings in red; ylim covers both series
plot(seq(500, 5000, 500), rcpp.time, type='o', xlim=c(0, 5000),
     ylim=c(0, max(rcpp.time, mvtnorm.time)), xlab='dimension', ylab='time (s)')

points(seq(500, 2500, 500), mvtnorm.time, type='o', col=2)

legend('topleft', legend=c('rcpp', 'mvtnorm'), lty=1, col=1:2, bty='n')

[Figure: elapsed time (s) versus dimension for rcpp (black) and mvtnorm (red)]
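If compiling C++ is not an option, the same Cholesky construction can be sketched in plain R. This is an illustrative sketch rather than part of the benchmark above, and no timing claims are made for it: `rmvnorm_chol` is a hypothetical helper name, not a function from any package. It also accepts a precomputed Cholesky factor, so when sigma is fixed the expensive factorization can be done once and reused across calls.

```r
# Hypothetical helper (not from any package): Cholesky-based MVN sampler.
# Pass either sigma, or a precomputed upper-triangular factor R = chol(sigma).
rmvnorm_chol <- function(n, mu, sigma = NULL, R = chol(sigma)) {
  d <- length(mu)
  stopifnot(nrow(R) == d, ncol(R) == d)
  # n x d matrix of iid standard normal draws
  Z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # Rows of Z %*% R have covariance t(R) %*% R = sigma; add mu to each row
  sweep(Z %*% R, 2, mu, "+")
}

# Usage: small example checking the sample moments
set.seed(1)
mu <- c(1, -2, 0.5)
sigma <- matrix(c(4,   1,   0.5,
                  1,   2,   0.3,
                  0.5, 0.3, 1), nrow = 3)
x <- rmvnorm_chol(1e5, mu, sigma)
colMeans(x)  # should be close to mu
cov(x)       # should be close to sigma

# Reuse one factorization for repeated draws with the same sigma
R <- chol(sigma)
y <- rmvnorm_chol(10, mu, R = R)
```

As a related option, mvtnorm::rmvnorm itself takes a `method` argument ("eigen", the default, "svd" or "chol"); `method = "chol"` uses a Cholesky factorization instead of the default eigendecomposition and may be worth trying before switching to compiled code.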

answered Oct 30 '22 at 14:10 by jbaums