Parallel RJAGS with convergence testing

Tags: r, mcmc, jags

I'm modifying an existing model using RJAGS. I'd like to run chains in parallel and periodically check the Gelman-Rubin convergence diagnostic to see whether I need to keep running. The problem is that if I need to resume sampling based on the diagnostic value, the recompiled chains restart from their original initial values rather than from the position in parameter space where each chain stopped. If I do not recompile the model, RJAGS complains. Is there a way to store the positions of the chains when they stop so that I can re-initialize from where I left off? Here is a very simplified example.

example1.bug:

model {
  for (i in 1:N) {
      x[i] ~ dnorm(mu,tau)
  }
  mu ~ dnorm(0,0.0001)
  tau <- pow(sigma,-2)
  sigma ~ dunif(0,100)
}

parallel_test.R:

#Make some fake data
N <- 1000
x <- rnorm(N,0,5)
write.table(x,
        file='example1.data',
        row.names=FALSE,
        col.names=FALSE)

library('rjags')
library('doParallel')
library('random')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

jags=list()
for (i in 1:getDoParWorkers()){
  jags[[i]] <- jags.model('example1.bug',
                          data=list('x'=x,'N'=N))
}

# Function to combine multiple mcmc lists into a single one
mcmc.combine <- function( ... ){
  return( as.mcmc.list( sapply( list( ... ),mcmc ) ) )
}

#Start with some burn-in
jags.parsamples <- foreach( i=1:getDoParWorkers(),
                           .inorder=FALSE,
                           .packages=c('rjags','random'),
                           .combine='mcmc.combine',
                           .multicombine=TRUE) %dopar%
{
  jags[[i]]$recompile()

  update(jags[[i]],100)
  jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)

  return(jags.samples)
}   

#Check the diagnostic output
print(gelman.diag(jags.parsamples[,'mu']))

counter <- 0

#my model doesn't converge so quickly, so let's simulate doing
#this updating 5 times:
#while(gelman.diag(jags.parsamples[,'mu'])[[1]][[2]] > 1.04)
while(counter < 5)
{
counter <- counter + 1
jags.parsamples <- foreach(i=1:getDoParWorkers(),
                             .inorder=FALSE,
                             .packages=c('rjags','random'),
                             .combine='mcmc.combine',
                             .multicombine=TRUE) %dopar%
  {
    #Here I lose the progress I've made
    jags[[i]]$recompile()
    jags.samples <- coda.samples(jags[[i]],c('mu','tau'),100)
    return(jags.samples)
  }
}

print(gelman.diag(jags.parsamples[,'mu']))
print(summary(jags.parsamples))
stopCluster(c1)

In the output, I see:

Iterations = 1001:2000

where I know there should be > 5000 iterations. (cross-posted to stats.stackexchange.com, which may be the more appropriate venue)

asked Apr 06 '15 20:04 by sjc


1 Answer

Every time your JAGS model runs on the worker nodes, the coda samples are returned but the state of the model is lost. So the next time it recompiles, it restarts from the beginning, as you are seeing. To get around this, you need to get and return the state of the model in your function (on the worker nodes), like so:

 endstate <- jags[[i]]$state(internal=TRUE)

Then you need to pass this back to the worker node and re-generate the model within the worker function using jags.model() with inits=endstate (for the appropriate chain).
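
As a rough sketch of that idea (not a tested drop-in replacement for the script in the question), each worker can return its draws together with its end state, and the next round can rebuild a one-chain model from that state. The helper run_chain, the variable names, the cap of 6 rounds, and the RNG seeds are all illustrative assumptions; the sketch also assumes example1.bug is readable from each worker's working directory. Note that jags.model() runs its adaptation phase again on every rebuild (see its n.adapt argument), so the chain moves on a little from the saved state before sampling resumes.

```r
library(rjags)
library(doParallel)

# Fake data as in the question; 'example1.bug' is assumed to be in the
# working directory of the master and of every worker.
N <- 1000
x <- rnorm(N, 0, 5)
dat <- list(x = x, N = N)

nchains <- 4
cl <- makeCluster(nchains)
registerDoParallel(cl)

# Hypothetical helper: build a one-chain model from 'inits', draw n.iter
# samples, and return the draws together with the final chain state.
run_chain <- function(inits, data, n.iter) {
  m <- jags.model('example1.bug', data = data, inits = inits,
                  n.chains = 1, quiet = TRUE)
  s <- coda.samples(m, c('mu', 'tau'), n.iter)
  list(draws = as.matrix(s[[1]]),
       state = m$state(internal = TRUE)[[1]])
}

# First round: give each chain its own RNG seed so the chains differ.
inits <- lapply(1:nchains, function(i)
  list(.RNG.name = 'base::Mersenne-Twister', .RNG.seed = i))

draws <- vector('list', nchains)
for (k in 1:6) {  # arbitrary cap on the number of rounds
  # Default .inorder = TRUE keeps res[[i]] matched to inits[[i]].
  res <- foreach(i = 1:nchains, .packages = 'rjags') %dopar%
    run_chain(inits[[i]], dat, 100)
  for (i in 1:nchains) {
    draws[[i]] <- rbind(draws[[i]], res[[i]]$draws)  # accumulate draws
    inits[[i]] <- res[[i]]$state                     # resume point
  }
  all.samples <- as.mcmc.list(lapply(draws, mcmc))
  # Stop once the upper C.I. of the Gelman-Rubin statistic is small enough.
  if (gelman.diag(all.samples[, 'mu'])$psrf[1, 2] < 1.04) break
}
stopCluster(cl)
print(gelman.diag(all.samples[, 'mu']))
```

Because every rebuilt model restarts its own iteration counter, the draws are accumulated as plain matrices and converted to an mcmc.list at the end, so the iteration numbers in the combined object are simply 1 to the total number of retained draws per chain.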

I would actually recommend looking at the runjags package that does all this for you. For example:

library('runjags')
parsamples <- run.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), sample=100, method='rjparallel')
summary(parsamples)
newparsamples <- extend.jags(parsamples, sample=100)
summary(newparsamples)
# etc

Or even:

parsamples <- autorun.jags('example1.bug', data=list('x'=x,'N'=N), monitor=c('mu','tau'), method='rjparallel')

Version 2 of runjags will hopefully be uploaded to CRAN soon, but for now you can download binaries from: https://sourceforge.net/projects/runjags/files/runjags/

Matt

answered Sep 28 '22 08:09 by Matt Denwood