
How to get the global_step when restoring checkpoints in Tensorflow?

Tags:

tensorflow

I'm saving my session state like so:

self._saver = tf.train.Saver()
self._saver.save(self._session, '/network', global_step=self._time)

When I later restore I want to get the value of the global_step for the checkpoint I restore from. This is in order to set some hyper parameters from it.

The hacky way to do this would be to parse the file names in the checkpoint directory. But surely there has to be a better, built-in way to do this?

asked Mar 20 '16 by Daniel Slater

People also ask

How do I restore a saved model in TensorFlow?

The first thing to do when restoring a TensorFlow model is to load the graph structure from the ".meta" file into the current graph. The current graph can then be explored using tf.get_default_graph().

What is Ckpt file in TensorFlow?

A .ckpt file is a binary file containing the values of the weights, biases, and all the other variables that were saved. Note that TensorFlow changed this format starting with version 0.11.

What does TF train saver do?

The Saver class adds ops to save and restore variables to and from checkpoints. It also provides convenience methods to run these ops. Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.


2 Answers

The general pattern is to keep a global_step variable in the graph to track training steps:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

Then you can save with

saver.save(sess, save_path, global_step=global_step)

When you restore the checkpoint, the value of global_step is restored along with the other variables.

answered Oct 18 '22 by Yaroslav Bulatov


This is a bit of a hack, but the other answers did not work for me at all

import os

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

# Extract the step from the checkpoint filename, e.g. 'network-1200' -> 1200
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
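The parsing itself is plain string handling; for a checkpoint saved as, say, network-1200 (path and step number made up for illustration):

```python
import os

# Illustrative value of ckpt.model_checkpoint_path
model_checkpoint_path = '/tmp/ckpt_demo/network-1200'

# basename -> 'network-1200', split on '-' -> ['network', '1200']
step = int(os.path.basename(model_checkpoint_path).split('-')[1])
print(step)  # 1200
```

Note this assumes the checkpoint prefix itself contains no hyphens; otherwise the split index would need adjusting.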

Update 9/2017

I'm not sure if this started working due to updates, but the following method seems to be effective in getting global_step to update and load properly:

Create two things: a variable to hold global_step and an op to increment it:

global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1,
                                      name='increment_global_step')

Now in your training loop, run the increment op every time you run your training op:

sess.run([train_op, increment_global_step], feed_dict=feed_dict)

If you ever want to retrieve your global step value as an integer at any point, just use the following command after loading the model:

sess.run(global_step)

This can be useful for creating filenames or computing the current epoch without keeping a second TensorFlow variable to hold that value. For instance, since global_step counts batches, calculating the current epoch on loading would be something like:

loaded_epoch = sess.run(global_step) * batch_size // num_train_records
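A small worked example of that arithmetic with made-up numbers (since global_step counts batches, multiply by batch_size before dividing by the number of training records):

```python
# Made-up numbers for illustration
restored_step = 12500  # value read back with sess.run(global_step)
batch_size = 32
num_train_records = 10000

examples_seen = restored_step * batch_size  # 400,000 examples processed
loaded_epoch = examples_seen // num_train_records
print(loaded_epoch)  # 40
```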

answered Oct 18 '22 by Lawrence Du