In Tensorflow, we can build and create multiple Tensorflow sessions using Between-graph Replication
for distributed training. MonitoredTrainingSession() coordinates multiple Tensorflow sessions,
and it takes a checkpoint_dir argument for restoring the Tensorflow session/graph. Now I have the following questions:

1. Usually we use tf.train.Saver() and saver.restore(...) to restore Tensorflow graphs. But how do we restore them when using MonitoredTrainingSession()?
2. Does MonitoredTrainingSession() work in testing (or prediction) mode?

I read the Tensorflow documentation, but didn't find the answers to these 2 questions. I would really appreciate it if anyone has solutions. Thanks!
Short answer: MonitoredTrainingSession() restores the graph from the latest checkpoint in checkpoint_dir automatically when the session is created, and you can run evaluation/prediction fetches inside the same monitored session (see the long answer below).
Long answer:
I will update my answer as I myself get a better view of what can be done with tf.train.MonitoredSession (tf.train.MonitoredTrainingSession simply creates a specialized version of tf.train.MonitoredSession, as can be seen in the source code).
The following example code shows how to save checkpoints to './ckpt_dir' every 5 seconds. If interrupted and restarted, training resumes from the last saved checkpoint:
def train(inputs, labels_onehot, global_step):
    # Note: `opt` is assumed to be a tf.train.Optimizer created elsewhere,
    # e.g. opt = tf.train.AdamOptimizer()
    out = tf.contrib.layers.fully_connected(
        inputs,
        num_outputs=10,
        activation_fn=tf.nn.sigmoid)
    loss = tf.reduce_mean(
        tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=out,
                labels=labels_onehot), axis=1))
    train_op = opt.minimize(loss, global_step=global_step)
    return train_op
with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    inputs = ...
    labels_onehot = ...
    train_op = train(inputs, labels_onehot, global_step)

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir='./ckpt_dir',
            save_checkpoint_secs=5,
            hooks=[ ... ]  # Choose your hooks
            ) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)
To achieve this, the MonitoredTrainingSession actually relies on three components:

1. a tf.train.Scaffold, which finalizes the graph and creates the init and saver ops used for saving and restoring,
2. a tf.train.ChiefSessionCreator, which creates the session and restores the latest checkpoint from the checkpoint directory, and
3. a tf.train.CheckpointSaverHook, which periodically saves new checkpoints.

In order to make it work, the tf.train.CheckpointSaverHook and tf.train.ChiefSessionCreator must be passed the same references to the checkpoint directory and scaffold. If the tf.train.MonitoredTrainingSession with its parameters from the example above were implemented with these 3 components, it would look something like this:
checkpoint_dir = './ckpt_dir'

scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir,
    save_secs=5,
    scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold,
    checkpoint_dir=checkpoint_dir
)

with tf.train.MonitoredSession(
        session_creator=session_creator,
        hooks=[saverhook]) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
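For pure prediction, one could reuse the same building blocks and simply omit the saver hook. A rough, untested sketch (predictions and test_it are placeholder names for a prediction op and a tf.data iterator built elsewhere):

# Restore the latest checkpoint from checkpoint_dir, but register no
# CheckpointSaverHook, so nothing is written back during prediction.
session_creator = tf.train.ChiefSessionCreator(
    scaffold=tf.train.Scaffold(),
    checkpoint_dir='./ckpt_dir'
)
with tf.train.MonitoredSession(session_creator=session_creator) as mon_sess:
    mon_sess.run(test_it.initializer)      # hypothetical tf.data iterator
    while not mon_sess.should_stop():
        preds = mon_sess.run(predictions)  # hypothetical prediction op
        # ... collect or post-process preds for each batch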
To combine training and cross-validation in one session, you can use tf.train.MonitoredSession.run_step_fn() together with functools.partial; run_step_fn executes a session call without triggering any hooks. The idea is to train your model for n iterations, then run your test set, reinitialize your iterators, and go back to training, and so on. Of course, you have to create your variables with reuse=tf.AUTO_REUSE for this to work. In code it looks like this:
from functools import partial

# Build model
...
with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
    ...
...

def step_fn(fetches, feed_dict, step_context):
    return step_context.session.run(fetches=fetches, feed_dict=feed_dict)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=...,
        save_checkpoint_steps=...,
        hooks=[...],
        ...
        ) as mon_sess:
    # Initialize iterators (assuming tf.data.Datasets are used)
    mon_sess.run_step_fn(
        partial(
            step_fn,
            [train_it.initializer,
             test_it.initializer,
             ...
            ],
            {}
        )
    )
    while not mon_sess.should_stop():
        # Train session
        for i in range(n):
            try:
                train_results = mon_sess.run(<train_fetches>)
            except Exception as e:
                break
        # Test session
        while True:
            try:
                test_results = mon_sess.run(<test_fetches>)
            except Exception as e:
                break
        # Reinitialize the iterators
        mon_sess.run_step_fn(
            partial(
                step_fn,
                [train_it.initializer,
                 test_it.initializer,
                 ...
                ],
                {}
            )
        )
The partial function simply performs partial application (closely related to currying in functional programming) on step_fn, which is then used in mon_sess.run_step_fn(). The code above has not been tested, and you might have to reinitialize train_it before you start the test session, but hopefully it is now clear how one could run both a training set and a validation set in the same run. This can furthermore be combined with the custom_scalar plugin of TensorBoard if you want to plot the training curve and the test curve in the same figure.
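To make the role of partial concrete, here is a small, untested illustration reusing step_fn and the hypothetical train_it/test_it iterators from the sketch above:

from functools import partial

# partial "bakes in" the first two arguments of step_fn (fetches and feed_dict),
# leaving a one-argument callable that only expects step_context:
init_fn = partial(step_fn, [train_it.initializer, test_it.initializer], {})

# mon_sess.run_step_fn(init_fn) therefore ends up calling
#   step_fn([train_it.initializer, test_it.initializer], {}, step_context)
# with the step_context supplied by the monitored session.
mon_sess.run_step_fn(init_fn)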
Lastly, this is the best implementation of this functionality that I have been able to come up with, and I personally hope that TensorFlow will make it a lot easier in the future, as it is quite tedious and probably not very efficient. I know there are tools such as the Estimator, which can run the train_and_evaluate function, but since that rebuilds the graph between each training and cross-validation run, it is very inefficient if you run on a single machine. I read somewhere that Keras + tf has this functionality, but as I do not use Keras + tf, that is not an option for me. Anyway, I hope this can help someone else out there struggling with the same things!
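For completeness, the Estimator alternative mentioned above would be wired up roughly like this (my_model_fn, train_input_fn and eval_input_fn are placeholder names, not part of the original answer):

import tensorflow as tf

# Hypothetical model_fn and input_fns; only the train_and_evaluate wiring is shown.
estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='./ckpt_dir')
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)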