
How does `MonitoredTrainingSession()` work with "restore" and "testing mode"?

In TensorFlow, we can build and create multiple sessions using between-graph replication for distributed training. MonitoredTrainingSession() coordinates multiple TensorFlow sessions, and it takes a checkpoint_dir argument that is used to restore the session/graph. I have the following questions:

  1. We normally use a tf.train.Saver() object to restore a TensorFlow graph via saver.restore(...). But how do we restore it when using MonitoredTrainingSession()?
  2. Since we run multiple processes and each process builds and creates a TensorFlow session for training, I wonder if we also have to run multiple processes for testing (or prediction) after training. In other words, how does MonitoredTrainingSession() work in testing (or prediction) mode?

I read the TensorFlow docs but didn't find answers to these two questions. I would really appreciate it if anyone has a solution. Thanks!

Asked by Ruofan Kong on Mar 29 '17 at 22:03


1 Answer

Short answer:

  1. You need to pass the global step to the optimizer's minimize() call that builds the train op you run with mon_sess.run. This makes it possible to both save checkpoints and restore them: the session picks up the latest checkpoint in checkpoint_dir when it starts.
  2. It is possible to run training and cross-validation simultaneously through a single MonitoredTrainingSession. First, you need to feed the training batches and the cross-validation batches through separate streams of the same graph (I recommend you look up this guide for info on how to do this). Second, you must pass to mon_sess.run() the train op for the training stream, as well as the loss (or whatever tensor you want to track) of the cross-validation stream. If you want to run a test session separately from training, feed only the test set through the graph and fetch only test_loss (or the other tensors you want to track). For more details on how this is done, see below.

Long answer:

I will update my answer as I get a better view of what can be done with tf.train.MonitoredSession (tf.train.MonitoredTrainingSession simply creates a specialized version of tf.train.MonitoredSession, as can be seen in the source code).

The following example code shows how you can save checkpoints every 5 seconds to './ckpt_dir'. When interrupted, it will restart from its last saved checkpoint:

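import tensorflow as tf

def train(inputs, labels_onehot, global_step):
    # Output raw logits: sigmoid_cross_entropy_with_logits applies the
    # sigmoid itself, so the layer must not apply it a second time.
    logits = tf.contrib.layers.fully_connected(
        inputs,
        num_outputs=10,
        activation_fn=None)
    loss = tf.reduce_mean(
        tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits,
                labels=labels_onehot), axis=1))
    # Assumption: plain gradient descent with an arbitrary learning rate;
    # swap in whichever tf.train optimizer you actually use.
    opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    # Passing global_step makes minimize() increment it on every step.
    train_op = opt.minimize(loss, global_step=global_step)
    return train_op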

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    inputs = ...
    labels_onehot = ...
    train_op = train(inputs, labels_onehot, global_step)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir='./ckpt_dir',
        save_checkpoint_secs=5,
        hooks=[ ... ] # Choose your hooks
    ) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

To achieve this, the MonitoredTrainingSession does three things:

  1. It creates a tf.train.Scaffold object, which works like a spider in the web; it gathers the pieces you need to train, save and load the model.
  2. It creates a tf.train.ChiefSessionCreator object. My knowledge of this one is limited, but as I understand it, it is used when your TensorFlow job is spread across multiple servers. My take is that it marks the process running the file as the chief: the one that initializes or restores the model, where the checkpoints should be saved, where loggers should log their data, etc.
  3. It creates a tf.train.CheckpointSaverHook, which is used to save the checkpoints.

To make this work, the tf.train.CheckpointSaverHook and the tf.train.ChiefSessionCreator must be given the same checkpoint directory and the same scaffold. If the tf.train.MonitoredTrainingSession from the example above were implemented with these three components directly, it would look something like this:

checkpoint_dir = './ckpt_dir'

scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir,
    save_secs=5,
    scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold,
    checkpoint_dir=checkpoint_dir
)

with tf.train.MonitoredSession(
        session_creator=session_creator,
        hooks=[saverhook]) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
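
For the restore and testing parts of the question, here is a minimal sketch of how a separate test run could reuse the same checkpoint directory. It assumes the TF1-style API is importable as tensorflow.compat.v1 (TF 1.14 or later, or TF 2.x). The helpers build_graph, train_steps and evaluate, and the variable named "weight", are hypothetical stand-ins for a real model, not part of TensorFlow. The point it illustrates is that no explicit saver.restore(...) call is needed: a session created with checkpoint_dir set restores the latest checkpoint it finds there.

```python
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF1-style API via compat.v1


def build_graph():
    """Hypothetical stand-in model: one scalar weight plus the global step."""
    global_step = tf.train.get_or_create_global_step()
    weight = tf.get_variable(
        "weight", shape=[], initializer=tf.zeros_initializer())
    # Stand-in for a real train op: add 1.0 to the weight, then bump the step.
    with tf.control_dependencies([tf.assign_add(weight, 1.0)]):
        train_op = tf.assign_add(global_step, 1)
    return weight, train_op


def train_steps(ckpt_dir, steps):
    """Run a few steps; restores from ckpt_dir first if a checkpoint exists."""
    with tf.Graph().as_default():
        _, train_op = build_graph()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=ckpt_dir,
                save_checkpoint_secs=5) as mon_sess:
            for _ in range(steps):
                mon_sess.run(train_op)
        # Leaving the block cleanly lets the saver hook write a final checkpoint.


def evaluate(ckpt_dir):
    """Rebuild the graph in one process and restore it without any saver hook."""
    with tf.Graph().as_default():
        weight, _ = build_graph()
        session_creator = tf.train.ChiefSessionCreator(checkpoint_dir=ckpt_dir)
        with tf.train.MonitoredSession(
                session_creator=session_creator) as mon_sess:
            # Fetch only what the test needs; the train op is never run here.
            return mon_sess.run(weight)


if __name__ == "__main__":
    ckpt_dir = tempfile.mkdtemp()
    train_steps(ckpt_dir, 3)
    print(evaluate(ckpt_dir))
```

Because evaluate() only reads the checkpoint, it can run as a single process even if training was distributed, as long as it builds a graph with the same variable names. Calling train_steps() again with the same directory continues from the restored values instead of starting over.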

To run training and cross-validation in the same session, you can use tf.train.MonitoredSession.run_step_fn() together with functools.partial. Inside the step function, step_context.session.run() executes a session call without calling any hooks. The idea is that you train your model for n iterations, then run your test set, reinitialize your iterators, go back to training your model, and so on. Of course, the training and test streams must share their variables, so you have to build your variable scopes with reuse=tf.AUTO_REUSE when doing this. The way to do this in code is shown below:

from functools import partial

import tensorflow as tf

# Build model
...

with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
    ...

...

def step_fn(fetches, feed_dict, step_context):
    return step_context.session.run(fetches=fetches, feed_dict=feed_dict)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=...,
        save_checkpoint_steps=...,
        hooks=[...],
        ...
) as mon_sess:

    # Initialize iterators (assuming tf.data.Dataset iterators are used)
    mon_sess.run_step_fn(
        partial(
            step_fn,
            [train_it.initializer,
             test_it.initializer,
             ...
            ],
            {}
        )
    )

    while not mon_sess.should_stop():
        # Train session: run up to n training steps
        for i in range(n):
            try:
                train_results = mon_sess.run(<train_fetches>)
            except tf.errors.OutOfRangeError:
                # The training iterator is exhausted
                break

        # Test session: run until the test iterator is exhausted
        while True:
            try:
                test_results = mon_sess.run(<test_fetches>)
            except tf.errors.OutOfRangeError:
                break

        # Reinitialize iterators
        mon_sess.run_step_fn(
            partial(
                step_fn,
                [train_it.initializer,
                 test_it.initializer,
                 ...
                ],
                {}
            )
        )

The partial function simply performs partial application (a close relative of currying from functional programming) on step_fn: it binds fetches and feed_dict in advance, which leaves the one-argument function of step_context that mon_sess.run_step_fn() expects. The above code has not been tested, and you might have to reinitialize train_it before you start the test session, but hopefully it is now clear how one could go about running both a training set and a validation set in the same run. This could furthermore be used together with the custom scalars plugin of TensorBoard if you want to plot both the training curve and the test curve in the same plot.
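
To make the role of partial concrete, here is a small framework-free sketch. FakeSession and FakeStepContext are hypothetical stand-ins that I made up for illustration (they are not TensorFlow classes), and the strings stand in for the iterator initializer ops. Only step_fn has the same shape as in the code above.

```python
from functools import partial


class FakeSession:
    """Hypothetical stand-in for the raw session; it just echoes its inputs."""

    def run(self, fetches, feed_dict=None):
        return ("ran", list(fetches), dict(feed_dict or {}))


class FakeStepContext:
    """Hypothetical stand-in for the step_context passed by run_step_fn()."""

    def __init__(self, session):
        self.session = session


def step_fn(fetches, feed_dict, step_context):
    # Same shape as the step_fn above: run the fetches on the raw session.
    return step_context.session.run(fetches=fetches, feed_dict=feed_dict)


# Bind the first two positional arguments now; only step_context stays open.
init_step = partial(step_fn, ["train_init", "test_init"], {})

# The caller supplies the one remaining argument, as run_step_fn() would.
result = init_step(FakeStepContext(FakeSession()))
print(result)
```

The bound function takes exactly one argument, which is why the same step_fn can be reused for any list of fetches without writing a new function each time.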

Lastly, this is the best implementation of this functionality that I have been able to come up with, and I hope that TensorFlow will make it a lot easier in the future, as it is quite tedious and probably not that efficient. I know that there are tools such as the Estimator that can be run with the train_and_evaluate function, but as this rebuilds the graph between each training and cross-validation run, it is very inefficient if you run on only a single computer. I read somewhere that Keras + tf has this functionality, but as I do not use Keras + tf, this is not an option for me. Anyway, I hope that this can help someone else out there struggling with the same things!

Answered by Andreas Forslöw on Sep 20 '22 at 00:09