
Fast sampling from Truncated Normal Distribution using Rcpp and openMP

Tags: r, rcpp, openmp

UPDATE:

I tried to implement Dirk's suggestions. Comments? I am busy right now at JSM, but I'd like to get some feedback before knitting an Rmd for the gallery. I switched back from Armadillo to normal Rcpp, as it didn't add any value. Scalar versions with R:: are quite nice. I should maybe add a parameter n for the number of draws when mean/sd are entered as scalars rather than as vectors of the desired output length.


There are lots of MCMC applications that require drawing samples from truncated Normal distributions. I built on an existing implementation of the TN and added parallel computation to it.

Issues:

  1. Does anyone see further potential speed improvements? In the last case from the benchmark, rtruncnorm is sometimes faster. The Rcpp implementation is always faster than existing packages, but can it be improved even further?
  2. I ran it within a complex model I can't share, and my R session crashed. However, I cannot systematically reproduce it, so it could have been another part of the code. If someone is working with the TN, please test it and let me know. Update: I haven't had issues with the updated code, but let me know.

How I put things together: To my knowledge, the fastest implementation is not on CRAN, but the source code can be downloaded from OSU stat. Competing implementations in msm and truncnorm were slower in my benchmarks. The trick is to efficiently adjust proposal distributions, where the Exponential works nicely for the tails of the truncated Normal. So I took Chris' code, "Rcpp'ed" it and added some openMP spice to it. The dynamic schedule is optimal here, as sampling can take more or less time depending on the boundaries. One thing I found nasty: lots of the statistical distributions are based on the NumericVector type, when I wanted to work with doubles. I just coded my way around that.
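
To see why the Exponential suits the tails, here is a short derivation. It assumes a standard normal target truncated to [a, b] with a > 0, and a proposal x = a + z where z is exponential with rate a, restricted to [0, b - a]. This appears to be what exp_rs does, since R::rexp takes a scale and is passed 1/a.

```latex
\begin{aligned}
f(a+z) &\propto \exp\!\left(-\tfrac{(a+z)^2}{2}\right)
        = \exp\!\left(-\tfrac{a^2}{2}\right)\exp(-az)\exp\!\left(-\tfrac{z^2}{2}\right),
        \qquad 0 \le z \le b-a,\\
g(z)   &\propto \exp(-az),\\
\frac{f(a+z)}{g(z)} &\propto \exp\!\left(-\tfrac{z^2}{2}\right) \le 1 .
\end{aligned}
```

A proposal z is therefore accepted with probability exp(-z^2/2), that is, when log(u) <= -z^2/2 for u ~ U(0,1). This is the complement of the `log(u) > (-0.5*z*z)` rejection condition in the code. For large a the proposal concentrates near z = 0, so almost every draw is accepted.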

Here's the Rcpp code:

#include <Rcpp.h>
#include <omp.h>


// norm_rs(a, b)
// generates a sample from a N(0,1) RV restricted to be in the interval
// (a,b) via rejection sampling.
// ======================================================================

// [[Rcpp::export]]

double norm_rs(double a, double b)
{
   double  x;
   x = Rf_rnorm(0.0, 1.0);
   while( (x < a) || (x > b) ) x = norm_rand();
   return x;
}

// half_norm_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) (with a > 0) using half normal rejection sampling.
// ======================================================================

// [[Rcpp::export]]

double half_norm_rs(double a, double b)
{
   double   x;
   x = fabs(norm_rand());
   while( (x<a) || (x>b) ) x = fabs(norm_rand());
   return x;
}

// unif_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) using uniform rejection sampling. 
// ======================================================================

// [[Rcpp::export]]

double unif_rs(double a, double b)
{
   double xstar, logphixstar, x, logu;

   // Find the argmax (b is always >= 0)
   // This works because we want to sample from N(0,1)
   if(a <= 0.0) xstar = 0.0;
   else xstar = a;
   logphixstar = R::dnorm(xstar, 0.0, 1.0, 1.0);

   x = R::runif(a, b);
   logu = log(R::runif(0.0, 1.0));
   while( logu > (R::dnorm(x, 0.0, 1.0,1.0) - logphixstar))
   {
      x = R::runif(a, b);
      logu = log(R::runif(0.0, 1.0));
   }
   return x;
}

// exp_rs(a, b)
// generates a sample from a N(0,1) RV restricted to the interval
// (a,b) using exponential rejection sampling.
// ======================================================================

// [[Rcpp::export]]

double exp_rs(double a, double b)
{
  double  z, u, rate;

  // R::rexp() takes the scale (mean), not the rate, so passing 1/a
  // draws from an exponential with rate a.
  rate = 1/a;

   // Generate a proposal on (0, b-a)
   z = R::rexp(rate);
   while(z > (b-a)) z = R::rexp(rate);
   u = R::runif(0.0, 1.0);

   while( log(u) > (-0.5*z*z))
   {
      z = R::rexp(rate);
      while(z > (b-a)) z = R::rexp(rate);
      u = R::runif(0.0,1.0);
   }
   return(z+a);
}




// rnorm_trunc( mu, sigma, lower, upper)
//
// generates one random normal RV with mean 'mu' and standard
// deviation 'sigma', truncated to the interval (lower,upper), where
// lower can be -Inf and upper can be Inf.
//======================================================================

// [[Rcpp::export]]
double rnorm_trunc (double mu, double sigma, double lower, double upper)
{
int change;
 double a, b;
 double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725;
 double z, tmp, lograt;

 change = 0;
 a = (lower - mu)/sigma;
 b = (upper - mu)/sigma;

 // First scenario
 if( (a == R_NegInf) || (b == R_PosInf))
   {
     if(a == R_NegInf)
       {
     change = 1;
     a = -b;
     b = R_PosInf;
       }

     // The two possibilities for this scenario
     if(a <= 0.45) z = norm_rs(a, b);
     else z = exp_rs(a, b);
     if(change) z = -z;
   }
 // Second scenario
 else if((a * b) <= 0.0)
   {
     // The two possibilities for this scenario
     if((R::dnorm(a, 0.0, 1.0,1.0) <= logt1) || (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1))
       {
     z = norm_rs(a, b);
       }
     else z = unif_rs(a,b);
   }
 // Third scenario
 else
   {
     if(b < 0)
       {
     tmp = b; b = -a; a = -tmp; change = 1;
       }

     lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0);
     if(lograt <= logt2) z = unif_rs(a,b);
     else if((lograt > logt1) && (a < t3)) z = half_norm_rs(a,b);
     else z = exp_rs(a,b);
     if(change) z = -z;
   }
   double output;
   output = sigma*z + mu;
 return (output);
}


// rtnm( mus, sigmas, lower, upper, cores)
//
// generates a vector of random normal RVs with means 'mus' and standard
// deviations 'sigmas', truncated to the intervals (lower,upper), where
// lower can be -Inf and upper can be Inf.
// mus, sigmas, lower, upper are vectors of equal length, and vectorized
// calls of this function speed up computation
// cores is an integer, representing the number of cores to be used in parallel
//======================================================================


// [[Rcpp::export]]

Rcpp::NumericVector rtnm(Rcpp::NumericVector mus, Rcpp::NumericVector sigmas, Rcpp::NumericVector lower, Rcpp::NumericVector upper, int cores){
  omp_set_num_threads(cores);
  int nobs = mus.size();
  Rcpp::NumericVector out(nobs);
  double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725;
    double a,b, z, tmp, lograt;

     int  change;

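  // a, b, z, tmp, lograt and change are declared outside the loop, so they
  // must be private to each thread; otherwise threads overwrite each other's values.
  // Note: R's RNG (norm_rand, R::runif, R::rexp) is not thread-safe, so calling
  // the samplers from several threads at once remains risky.
  #pragma omp parallel for schedule(dynamic) private(a, b, z, tmp, lograt, change)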
  for(int i=0;i<nobs;i++) {  

     a = (lower(i) - mus(i))/sigmas(i);
     b = (upper(i) - mus(i))/sigmas(i);
     change=0;
     // First scenario
     if( (a == R_NegInf) || (b == R_PosInf))
       {
         if(a == R_NegInf)
           {
              change = 1;
              a = -b;
              b = R_PosInf;
           }

         // The two possibilities for this scenario
         if(a <= 0.45) z = norm_rs(a, b);
         else z = exp_rs(a, b);
         if(change) z = -z;
       }
     // Second scenario
     else if((a * b) <= 0.0)
       {
         // The two possibilities for this scenario
         if((R::dnorm(a, 0.0, 1.0,1.0) <= logt1) || (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1))
           {
                z = norm_rs(a, b);
           }
         else z = unif_rs(a,b);
       }

     // Third scenario
     else
       {
         if(b < 0)
           {
                tmp = b; b = -a; a = -tmp; change = 1;
           }

         lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0);
         if(lograt <= logt2) z = unif_rs(a,b);
         else if((lograt > logt1) && (a < t3)) z = half_norm_rs(a,b);
         else z = exp_rs(a,b);
         if(change) z = -z;
       }
    out(i)=sigmas(i)*z + mus(i);          
  }

return(out);
}
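
A possible way to keep each thread's state separate is sketched below. This is a minimal sketch under stated assumptions, not the code from the question. It uses `std::mt19937_64` from the C++ standard library instead of R's RNG, with one engine per thread, and a plain rejection sampler stands in for the scenario-based algorithm. The names `norm_rs_local`, `rtn_threads` and the `seed` parameter are hypothetical.

```cpp
#include <random>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// Hypothetical helper: plain rejection sampler for N(0,1) restricted to
// [a, b], drawing from an engine owned by the calling thread.
double norm_rs_local(double a, double b, std::mt19937_64& eng) {
    std::normal_distribution<double> std_normal(0.0, 1.0);
    double x = std_normal(eng);
    while (x < a || x > b) x = std_normal(eng);
    return x;
}

// Hypothetical vectorised sampler. All inputs are assumed to have equal length.
std::vector<double> rtn_threads(const std::vector<double>& mu,
                                const std::vector<double>& sigma,
                                const std::vector<double>& lower,
                                const std::vector<double>& upper,
                                unsigned seed) {
    const long n = static_cast<long>(mu.size());
    std::vector<double> out(mu.size());

    #pragma omp parallel
    {
        int tid = 0;
#ifdef _OPENMP
        tid = omp_get_thread_num();
#endif
        // One engine per thread, seeded from (seed, thread id).
        std::seed_seq seq{seed, static_cast<unsigned>(tid)};
        std::mt19937_64 eng(seq);

        #pragma omp for schedule(dynamic)
        for (long i = 0; i < n; ++i) {
            // Declared inside the loop, so every thread has its own copy.
            const double a = (lower[i] - mu[i]) / sigma[i];
            const double b = (upper[i] - mu[i]) / sigma[i];
            out[i] = sigma[i] * norm_rs_local(a, b, eng) + mu[i];
        }
    }
    return out;
}
```

With the dynamic schedule, which thread handles which index can change between runs, so results are only reproducible for a given seed when a single thread is used.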

And here is the benchmark:

libs=c("truncnorm","msm","inline","Rcpp","RcppArmadillo","rbenchmark")
if( sum(!(libs %in% .packages(all.available = TRUE)))>0){ install.packages(libs[!(libs %in% .packages(all.available = TRUE))])}
for(i in 1:length(libs)) {library(libs[i],character.only = TRUE,quietly=TRUE)}


#needed for openMP parallel
Sys.setenv("PKG_CXXFLAGS"="-fopenmp")
Sys.setenv("PKG_LIBS"="-fopenmp")

#no of cores for openMP version
cores = 4

#source code from same dir
Rcpp::sourceCpp('truncnorm.cpp')


#sample size
nn=1000000


bb= 100
aa=-100
benchmark( rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3    )[,1:4]

aa=0 
benchmark( rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3    )[,1:4]

aa=2
benchmark( rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3    )[,1:4]

aa=50
benchmark( rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3    )[,1:4]

Several benchmark runs are necessary as the speed depends on the upper/lower boundaries. For different cases, different parts of the algorithm kick in.

Inferrator asked Jul 29 '13 02:07



1 Answer

Really quick comments:

  1. if you include RcppArmadillo.h you do not need to include Rcpp.h -- in fact, you should not and we even test that

  2. rep(oneDraw, n) makes n calls. I would write a function to be called once that returns you n draws -- it will be faster as you save yourself n-1 function call overheads

  3. Your comment that "lots of the statistical distributions are based on the NumericVector type, when I wanted to work with doubles" may reveal some misunderstanding: NumericVector is our convenient proxy class for internal R types: no copies. You are free to use std::vector<double> or whichever form you prefer.

  4. I know little about truncated normals so I cannot comment on the specifics of your algorithms.

  5. Once you have it worked out consider a post for the Rcpp Gallery.
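
Point 2 above could look roughly like the following. This is only a sketch in standard C++, under a few assumptions: plain rejection sampling stands in for the scenario-based sampler from the question, `std::mt19937_64` replaces R's RNG, and the function name `rtn_n` and its `seed` parameter are hypothetical. The point it illustrates is the call structure: scalar parameters go in once, and n draws come back from a single call.

```cpp
#include <random>
#include <vector>

// Hypothetical function: returns n draws from N(mu, sigma^2)
// truncated to [lower, upper] in a single call.
std::vector<double> rtn_n(int n, double mu, double sigma,
                          double lower, double upper, unsigned seed) {
    std::vector<double> out;
    if (n <= 0) return out;
    out.reserve(static_cast<std::size_t>(n));

    std::mt19937_64 eng(seed);
    std::normal_distribution<double> std_normal(0.0, 1.0);

    // Standardise the bounds once instead of once per draw.
    const double a = (lower - mu) / sigma;
    const double b = (upper - mu) / sigma;

    // Plain rejection sampling: keep standard normal draws that fall in [a, b].
    while (static_cast<int>(out.size()) < n) {
        const double z = std_normal(eng);
        if (z >= a && z <= b) out.push_back(sigma * z + mu);
    }
    return out;
}
```

In an Rcpp setting the same shape would presumably return a NumericVector and call the scalar sampler inside the loop, so the per-call overhead from R is paid once rather than n times.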

Dirk Eddelbuettel answered Sep 22 '22 11:09
