I'm saving my session state like so:
self._saver = tf.train.Saver()
self._saver.save(self._session, '/network', global_step=self._time)
When I later restore, I want to get the value of global_step for the checkpoint I restore from, in order to set some hyperparameters from it.
The hacky way to do this would be to run through and parse the file names in the checkpoint directory. But surely there has to be a better, built-in way to do this?
Restoring Models
The first thing to do when restoring a TensorFlow model is to load the graph structure from the ".meta" file into the current graph. The current graph can then be explored using tf.get_default_graph().
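As a minimal sketch of that step (the checkpoint prefix /tmp/my_model.ckpt-1000 here is hypothetical):
import tensorflow as tf
# Rebuild the saved graph structure from the .meta file (path is hypothetical).
saver = tf.train.import_meta_graph('/tmp/my_model.ckpt-1000.meta')
with tf.Session() as sess:
    # Restore the variable values into the imported graph.
    saver.restore(sess, '/tmp/my_model.ckpt-1000')
    # The imported graph is now the default graph and can be explored.
    print([op.name for op in tf.get_default_graph().get_operations()][:10])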
The checkpoint itself is a binary file which contains the values of the weights, biases, and all the other saved variables. This file traditionally had the extension .ckpt; however, TensorFlow changed the format in version 0.11, after which a checkpoint is written as a set of .data and .index files rather than a single .ckpt file.
The Saver class adds ops to save and restore variables to and from checkpoints. It also provides convenience methods to run these ops. Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.
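For a quick look at what a checkpoint contains without rebuilding a graph, a checkpoint reader also works. A hedged sketch (the path is hypothetical, and the tensor name 'global_step' assumes a variable was saved under that name):
import tensorflow as tf
# Open the checkpoint directly (path is hypothetical).
reader = tf.train.NewCheckpointReader('/tmp/my_model.ckpt-1000')
# Map of variable name -> shape for everything stored in the checkpoint.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
# Read a single stored value, e.g. the global step.
print(reader.get_tensor('global_step'))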
The general pattern is to have a global_step variable to keep track of steps:
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
Then you can save with
saver.save(sess, save_path, global_step=global_step)
When you restore, the value of global_step is restored as well.
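A minimal sketch of the restore side (the directory /tmp/checkpoint_dir is hypothetical, and the rest of the model graph is assumed to have been rebuilt exactly as it was at save time):
import tensorflow as tf
global_step = tf.Variable(0, name='global_step', trainable=False)
saver = tf.train.Saver()
with tf.Session() as sess:
    # Find the most recent checkpoint in the directory (path is hypothetical).
    ckpt = tf.train.latest_checkpoint('/tmp/checkpoint_dir')
    saver.restore(sess, ckpt)
    # global_step now holds the step at which the checkpoint was written.
    print(sess.run(global_step))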
This is a bit of a hack, but the other answers did not work for me at all:
import os
import tensorflow as tf
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
# Extract the step from the checkpoint filename, e.g. 'model.ckpt-1000' -> 1000
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
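Note that this relies on the default Saver naming convention, where the step is appended to the checkpoint prefix after a dash; if your checkpoint basename itself contains dashes, taking the last field (split('-')[-1]) is the safer choice.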
Update 9/2017
I'm not sure if this started working due to updates, but the following method seems to be effective in getting global_step to update and load properly:
Create two nodes: a Variable to hold global_step and an op to increment it:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1, name='increment_global_step')
Now in your training loop run the increment op every time you run your training op.
sess.run([train_op, increment_global_step], feed_dict=feed_dict)
If you ever want to retrieve your global step value as an integer at any point, just use the following command after loading the model:
sess.run(global_step)
This can be useful for creating filenames or calculating your current epoch without having a second TensorFlow Variable to hold that value. For instance, since global_step increments once per batch, calculating the current epoch on loading would be something like:
loaded_epoch = sess.run(global_step) * batch_size // num_train_records
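As a quick sanity check with hypothetical numbers: with batch_size = 100 and num_train_records = 50000, a restored global_step of 125000 gives 125000 * 100 // 50000 = 250 completed epochs.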