
tf.train.MonitoredTrainingSession and reinitializable iterator from Dataset

It seems as if a MonitoredTrainingSession does some operations (logging?) before the first call to .run(...), meaning that when I do:

import tensorflow as tf

train_data = reader.traindata()  # returns a tf.contrib.data.Dataset
# Reinitializable iterator: defined by structure, bound to a dataset via an initializer
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)

# ... no calls to ts.run ...

ts.run(init_train)

This yields the error:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element

So it seems as if the MonitoredTrainingSession is doing some operations before running the operation I feed it, which makes it impossible to use together with a reinitializable iterator from Dataset.

I am sure I am missing something and would love to hear what :-)

Viktor Ogeman asked Aug 29 '17 18:08



1 Answer

It looks like there is no direct solution in TensorFlow yet. It is admittedly odd that MonitoredTrainingSession does not fully support the Dataset API.

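The reason is that the monitored session skips running init_op when it restores from a checkpoint. Hence the iterator initializer should be treated like a local variable initializer, i.e. run as part of local_init_op, which runs whether or not a checkpoint was restored.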
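To make that ordering concrete, here is a toy Python model of how a monitored session prepares its underlying session. It is a simplified, hypothetical simulation, not TensorFlow code: the names prepare_session, init_iterator and noop are made up for illustration. It only mirrors the rule that init_op is skipped after a restore while local_init_op always runs.

```python
# Toy model (hypothetical, NOT the TensorFlow API) of session preparation:
# init_op only runs on a fresh start, local_init_op runs every time.

def prepare_session(has_checkpoint, init_op, local_init_op, log):
    """Simulate session preparation; `log` records which steps ran."""
    if has_checkpoint:
        log.append("restore")   # variables are loaded from the checkpoint
    else:
        init_op(log)            # only when nothing was restored
    local_init_op(log)          # always, fresh start or restore
    return log


def init_iterator(log):
    """Stand-in for running the iterator initializer (init_train)."""
    log.append("init_train")


def noop(log):
    """Stand-in for an op that does nothing relevant here."""
    pass


# Iterator initializer placed in init_op: it is skipped after a restore.
fresh = prepare_session(False, init_iterator, noop, [])
restored = prepare_session(True, init_iterator, noop, [])

# Iterator initializer placed in local_init_op: it runs in both cases.
fresh_local = prepare_session(False, noop, init_iterator, [])
restored_local = prepare_session(True, noop, init_iterator, [])

if __name__ == "__main__":
    print(fresh, restored)
    print(fresh_local, restored_local)
```

In the model, only the local_init_op placement still runs init_train after a restore, which is the reason the workaround puts the initializer there.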

The currently suggested workaround is given in this issue: https://github.com/tensorflow/tensorflow/issues/12859

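import tensorflow as tf

# Run the iterator initializer in local_init_op, which the session runs on
# every start, including when it restores from a checkpoint.
# init_train, train_op and checkpoint_dir are defined elsewhere in your graph code.
scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(),
                                                    init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                       checkpoint_dir=checkpoint_dir) as sess:
    while not sess.should_stop():
        sess.run(train_op)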
Michael Jaison G answered Oct 26 '22 22:10
