I'm modifying an existing model using rjags. I'd like to run chains in parallel and occasionally check the Gelman-Rubin convergence diagnostic to decide whether I need to keep running. The problem is that when I resume running based on the diagnostic value, the recompiled chains restart from their initial values rather than from the position in parameter space where they stopped, and if I do not recompile the model, rjags complains. Is there a way to store the positions of the chains when they stop so I can re-initialize from where I left off? Here is a very simplified example.
example1.bug:
model {
  for (i in 1:N) {
    x[i] ~ dnorm(mu, tau)
  }

  mu ~ dnorm(0, 0.0001)
  tau <- pow(sigma, -2)
  sigma ~ dunif(0, 100)
}
parallel_test.R:
# Make some fake data
N <- 1000
x <- rnorm(N, 0, 5)
write.table(x,
            file = 'example1.data',
            row.names = FALSE,
            col.names = FALSE)

library('rjags')
library('doParallel')
library('random')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

jags <- list()
for (i in 1:getDoParWorkers()) {
  jags[[i]] <- jags.model('example1.bug',
                          data = list('x' = x, 'N' = N))
}
# Function to combine multiple mcmc lists into a single one
mcmc.combine <- function(...) {
  return(as.mcmc.list(sapply(list(...), mcmc)))
}

# Start with some burn-in
jags.parsamples <- foreach(i = 1:getDoParWorkers(),
                           .inorder = FALSE,
                           .packages = c('rjags', 'random'),
                           .combine = 'mcmc.combine',
                           .multicombine = TRUE) %dopar%
{
  jags[[i]]$recompile()
  update(jags[[i]], 100)
  jags.samples <- coda.samples(jags[[i]], c('mu', 'tau'), 100)
  return(jags.samples)
}
# Check the diagnostic output
print(gelman.diag(jags.parsamples[, 'mu']))

counter <- 0
# My model doesn't converge so quickly, so let's simulate doing
# this updating 5 times:
# while (gelman.diag(jags.parsamples[, 'mu'])[[1]][[2]] > 1.04)
while (counter < 5)
{
  counter <- counter + 1
  jags.parsamples <- foreach(i = 1:getDoParWorkers(),
                             .inorder = FALSE,
                             .packages = c('rjags', 'random'),
                             .combine = 'mcmc.combine',
                             .multicombine = TRUE) %dopar%
  {
    # Here I lose the progress I've made
    jags[[i]]$recompile()
    jags.samples <- coda.samples(jags[[i]], c('mu', 'tau'), 100)
    return(jags.samples)
  }
}
print(gelman.diag(jags.parsamples[,'mu']))
print(summary(jags.parsamples))
stopCluster(c1)
In the output, I see:
Iterations = 1001:2000
where I know there should be > 5000 iterations. (cross-posted to stats.stackexchange.com, which may be the more appropriate venue)
Every time your JAGS model runs on the worker nodes, the coda samples are returned but the state of the model is lost, so the next recompile restarts from the beginning, as you are seeing. To get around this, you need to get and return the state of the model in your worker function, like so:
endstate <- jags[[i]]$state(internal=TRUE)
Then you need to return this end state (along with the samples) to the main process, and on the next pass re-create the model inside the worker function with jags.model(), passing inits=endstate for the appropriate chain.
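For example, here is a minimal sketch of that pattern, building on the parallel_test.R script above (x and N are assumed to already exist from that script; the per-chain state list, the rebuild-inside-foreach structure and the fixed five passes are my own illustration rather than anything prescribed by rjags):

library('rjags')
library('doParallel')

nchains <- 4
c1 <- makeCluster(nchains)
registerDoParallel(c1)

# Give each chain its own RNG type and seed for the first pass; after that,
# 'state' holds the saved end state (parameter values + RNG) of each chain
state <- parallel.seeds("base::BaseRNG", nchains)

for (pass in 1:5) {

  out <- foreach(i = 1:nchains,
                 .inorder = TRUE,
                 .packages = 'rjags') %dopar%
  {
    # Rebuild the model on the worker. On the first pass state[[i]] only
    # seeds the RNG; on later passes it also contains the last parameter
    # values, so the chain continues from where it stopped instead of
    # restarting from the priors.
    m <- jags.model('example1.bug',
                    data = list('x' = x, 'N' = N),
                    inits = state[[i]],
                    n.chains = 1,
                    quiet = TRUE)
    samples <- coda.samples(m, c('mu', 'tau'), 100)

    # Return both the samples and the end state of this chain
    list(samples = samples[[1]],
         state = m$state(internal = TRUE)[[1]])
  }

  # Collect the samples for diagnostics and the end states for the next pass
  jags.parsamples <- do.call(mcmc.list, lapply(out, function(o) o$samples))
  state <- lapply(out, function(o) o$state)

  print(gelman.diag(jags.parsamples[, 'mu']))
}

stopCluster(c1)

Note that rebuilding the model this way still repeats compilation and the default adaptation phase on every pass, which is part of why I would point you at runjags below.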
I would actually recommend looking at the runjags package, which does all of this for you. For example:
library('runjags')

parsamples <- run.jags('example1.bug', data = list('x' = x, 'N' = N),
                       monitor = c('mu', 'tau'), sample = 100,
                       method = 'rjparallel')
summary(parsamples)

newparsamples <- extend.jags(parsamples, sample = 100)
summary(newparsamples)

# etc
Or even:
parsamples <- autorun.jags('example1.bug', data = list('x' = x, 'N' = N),
                           monitor = c('mu', 'tau'), method = 'rjparallel')
Version 2 of runjags will hopefully be uploaded to CRAN soon, but for now you can download binaries from: https://sourceforge.net/projects/runjags/files/runjags/
Matt