 

How to get epoch num info from tf.train.string_input_producer

Tags:

tensorflow

If reading files using string_input_producer, like

filename_queue = tf.train.string_input_producer(
  files, 
  num_epochs=num_epochs,
  shuffle=shuffle)

How can I get the current epoch number during training? I want to display it as training progresses. I tried the following:

Running
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/epochs:0')

always returns the same value as the epoch limit (num_epochs).

Running
tf.get_default_graph().get_tensor_by_name('input_train/input_producer/limit_epochs/CountUpTo:0')

increments the value by 1 on every call.

Neither gives the correct epoch number during training.

Also, if I resume training from an existing model, can I recover the number of epochs that have already been trained?

asked Aug 23 '16 at 12:08 by allen



1 Answer

I think the right approach here is to define a global_step variable that you pass to your optimizer (or you can increment it manually).

The TensorFlow Mechanics 101 tutorial provides an example:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

Now global_step is incremented each time train_op runs. Since you know the size of your dataset and your batch size, you can work out which epoch you're in: the number of completed epochs is global_step * batch_size // num_examples.
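A minimal sketch of that arithmetic in plain Python, assuming every training step consumes exactly batch_size examples (no smaller final batch). The helper names current_epoch and epoch_progress are hypothetical, not part of TensorFlow; you would feed them the value returned by sess.run(global_step).

```python
def current_epoch(global_step, batch_size, num_examples):
    """0-based index of the epoch currently being trained.

    Hypothetical helper; assumes each step consumes exactly batch_size examples.
    """
    if batch_size <= 0 or num_examples <= 0:
        raise ValueError("batch_size and num_examples must be positive")
    examples_seen = global_step * batch_size
    return examples_seen // num_examples


def epoch_progress(global_step, batch_size, num_examples):
    """Fractional number of epochs seen so far (2.5 = halfway through the 3rd epoch)."""
    if batch_size <= 0 or num_examples <= 0:
        raise ValueError("batch_size and num_examples must be positive")
    return global_step * batch_size / num_examples


if __name__ == "__main__":
    # 10,000 examples, batch size 100 -> 100 steps per epoch.
    step = 250
    print(current_epoch(step, 100, 10000))   # 2
    print(epoch_progress(step, 100, 10000))  # 2.5
```

Because string_input_producer keeps cycling through the files, batches can straddle epoch boundaries, so this count is a close approximation rather than an exact file-level epoch.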

When you save your model with a tf.train.Saver(), the global_step variable is saved along with the other variables. After restoring, call global_step.eval() (with the session as the default session) or sess.run(global_step) to get back the step value where you left off.
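If you also pass global_step to saver.save(sess, save_path, global_step=global_step), TensorFlow appends "-<step>" to the checkpoint prefix (for example model.ckpt-2500). A minimal sketch, assuming that naming, of reading the step back from a path such as the one returned by tf.train.latest_checkpoint and converting it to epochs already trained; the helper names are hypothetical.

```python
import re


def step_from_checkpoint_path(ckpt_path):
    """Return the trailing '-<step>' number of a checkpoint prefix, or None.

    Assumes the checkpoint was written with saver.save(..., global_step=step).
    """
    match = re.search(r"-(\d+)$", ckpt_path)
    return int(match.group(1)) if match else None


def epochs_trained(ckpt_path, batch_size, num_examples):
    """Fractional epochs already trained, or None if the path has no step suffix."""
    step = step_from_checkpoint_path(ckpt_path)
    if step is None:
        return None
    return step * batch_size / num_examples


if __name__ == "__main__":
    path = "/tmp/train/model.ckpt-2500"
    print(step_from_checkpoint_path(path))      # 2500
    print(epochs_trained(path, 100, 10000))     # 25.0
```

This is only a convenience for logging before a session exists; restoring the global_step variable itself remains the authoritative source.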

I hope this helps!

answered Nov 12 '22 at 16:11 by chillinger
