
tensorflow error: restore checkpoint file

Tags:

tensorflow

I built my own convolutional neural network, in which I track the moving averages of all trainable variables (TensorFlow 1.0):

variable_averages = tf.train.ExponentialMovingAverage(
        0.9999, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
train_op = tf.group(apply_gradient_op, variables_averages_op)
saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
summary_op = tf.summary.merge(summaries)
init = tf.global_variables_initializer()
sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False))
sess.run(init)
# start queue runners
tf.train.start_queue_runners(sess=sess)

summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

# training loop
start_time = time.time()
for step in range(FLAGS.max_steps):
        _, loss_value = sess.run([train_op, loss])
        duration = time.time() - start_time
        start_time = time.time()
        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 1 == 0:
            # print current model status
            num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
            examples_per_sec = num_examples_per_step/duration
            sec_per_batch = duration/FLAGS.num_gpus
            format_str = '{} step {}, loss {}, {} examples/sec, {} sec/batch'
            print(format_str.format(datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))
        if step % 50 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
        if step % 10 == 0 or step == FLAGS.max_steps - 1:  # also save after the final step
            print('save checkpoint')
            # save checkpoint file
            checkpoint_file = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_file, global_step=step)
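
For reference, the update that tf.train.ExponentialMovingAverage applies to each shadow variable can be sketched in plain Python. This is only an illustration of the documented rule, not the library code; the helper names `effective_decay` and `ema_update` are made up for this sketch. Because `global_step` is passed as the second argument (`num_updates`), the decay actually used early in training is smaller than 0.9999.

```python
def effective_decay(decay, num_updates=None):
    """Decay actually used when num_updates (here: global_step) is given."""
    if num_updates is None:
        return decay
    # Documented rule: min(decay, (1 + num_updates) / (10 + num_updates))
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def ema_update(shadow, value, decay):
    """One moving-average step: shadow -= (1 - decay) * (shadow - value)."""
    return shadow - (1.0 - decay) * (shadow - value)


# Illustrative numbers only: a shadow value starting at 1.0 that tracks
# a variable which stays at 3.0 for three training steps.
shadow = 1.0
for step, value in enumerate([3.0, 3.0, 3.0]):
    shadow = ema_update(shadow, value, effective_decay(0.9999, num_updates=step))
    print(step, shadow)
```

The shadow values are stored as separate variables whose names end in /ExponentialMovingAverage, which is why that suffix shows up in checkpoint keys.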

The training script works fine and checkpoint files are saved (saver version V2). Then I try to restore the checkpoints in another script to evaluate the model. There I have this piece of code:

# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

This is where I get the error "NotFoundError (see above for traceback): Key conv1/Variable/ExponentialMovingAverage not found in checkpoint", where conv1/Variable/ is a variable scope.

This error occurs even before I try to restore the variables. Can you please help me solve it?
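
One way to narrow down an error like this is to compare the keys the eval Saver will ask for with the keys that are actually in the checkpoint. The sketch below is plain Python and only mimics, in simplified form, how `variables_to_restore()` builds its name map (trainable variables are looked up under `<name>/ExponentialMovingAverage`, other variables under their own name). The helper names and the example variable names are hypothetical.

```python
def expected_restore_keys(trainable_names, other_names,
                          ema_name='ExponentialMovingAverage'):
    """Simplified model of the checkpoint keys variables_to_restore() uses."""
    keys = [name + '/' + ema_name for name in trainable_names]
    keys += list(other_names)
    return keys


def missing_keys(expected, checkpoint_keys):
    """Keys the Saver would request that the checkpoint does not contain."""
    present = set(checkpoint_keys)
    return sorted(key for key in expected if key not in present)


# Hypothetical checkpoint contents for illustration. With TensorFlow 1.x the
# real keys can be listed with
# tf.train.NewCheckpointReader(path).get_variable_to_shape_map().
checkpoint_keys = ['conv1/Variable', 'global_step']
expected = expected_restore_keys(['conv1/Variable'], ['global_step'])
print(missing_keys(expected, checkpoint_keys))
```

Any key printed here is one that `saver.restore` would fail on, so the list shows whether the mismatch comes from the averages not being saved or from variable names that differ between the two graphs.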

Thanks in advance

TheJude

TheJude asked Mar 09 '17 14:03



1 Answer

I solved it this way:
Call tf.reset_default_graph() before creating the second ExponentialMovingAverage(...) in the graph.

# Reset the default graph before creating a new EMA object.
tf.reset_default_graph()
# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

It took me 2 hours...
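
A minimal round trip of the pattern above, as a sketch under stated assumptions rather than the original scripts: `build_model` is a hypothetical stand-in for the real network (a single variable in the conv1 scope), the decay value is illustrative, and the `tensorflow.compat.v1` import is only there so the TF 1.x API also runs on newer installs (on TensorFlow 1.0 itself, use a plain `import tensorflow as tf` and drop the `disable_v2_behavior()` call). Note that the model is rebuilt after the reset, so the fresh graph has variables for the averages to be restored into.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: a TF version that ships compat.v1

tf.disable_v2_behavior()

MOVING_AVERAGE_DECAY = 0.9  # illustrative value


def build_model():
    # Hypothetical stand-in for the real network.
    with tf.variable_scope('conv1'):
        return tf.Variable(1.0, name='Variable')


# Training side: track the moving average and save all global variables.
tf.reset_default_graph()
weights = build_model()
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
ema_op = ema.apply([weights])
saver = tf.train.Saver(tf.global_variables())
ckpt_dir = tempfile.mkdtemp()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(weights, 3.0))
    sess.run(ema_op)
    expected = sess.run(ema.average(weights))
    saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=0)

# Eval side: reset, rebuild the model, then map the averages onto it.
tf.reset_default_graph()
weights = build_model()
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    restored = sess.run(weights)

print(sorted(variables_to_restore.keys()), restored, expected)
```

If the eval graph is built with the same variable names as the training graph, the restored weight equals the saved moving average; if the names differ (for example because the graph was built twice without a reset), the requested keys no longer match the checkpoint.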

Tao answered Oct 29 '22 19:10