The TensorFlow programmer's guide recommends using a feedable iterator to switch between the training and validation datasets without reinitializing the iterator. It mainly requires feeding a handle to choose between them.
How can it be used together with tf.train.MonitoredTrainingSession?
The following approach fails with the error "RuntimeError: Graph is finalized and cannot be modified.":
with tf.train.MonitoredTrainingSession() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
How can I get both the convenience of MonitoredTrainingSession and the ability to iterate over the training and validation datasets simultaneously?
I got the answer from the TensorFlow GitHub issue: https://github.com/tensorflow/tensorflow/issues/12859
The solution is to invoke iterator.string_handle() before creating the MonitoredSession.
import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

# Create the string-handle ops while the graph can still be modified.
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

# The fed handle selects which iterator next_batch reads from.
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
    # Only evaluate the existing ops here; no new ops are added to the graph.
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

        if step % 3 == 0:
            print('val', sess.run(next_batch, feed_dict={handle: handle_val}))
Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)
@Michael Jaison G's answer is correct. However, it does not work when you also want to use session_run_hooks that need to evaluate parts of the graph, such as LoggingTensorHook or SummarySaverHook. The example below causes an error:
import tensorflow as tf

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()

pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()

summary_hook = tf.train.SummarySaverHook(save_steps=5,
    output_dir="summaries", summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        feat = sess.run(feature, feed_dict={handle: handle_train})
        pred_ = sess.run(pred, feed_dict={handle: handle_train})
        print('train: ', feat)
        print('pred: ', pred_)

        if step % 3 == 0:
            print('val', sess.run(feature, feed_dict={handle: handle_val}))
This will fail with error:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
[[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
[[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
The reason is that the hook already tries to evaluate the graph on the first session.run([iter_train_handle, iter_val_handle]) call, whose feed_dict does not yet contain a handle.
One workaround is to override the hooks that cause the problem and change the code in before_run and after_run so that they only evaluate on session.run calls whose feed_dict contains the handle (you can access the feed_dict of the current session.run call via the run_context argument of before_run and after_run).
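As a rough illustration of that workaround, here is a minimal sketch of a wrapper that forwards before_run and after_run to another hook only when the handle placeholder is present in the feed_dict of the current call. HandleGatedHook is a hypothetical name, not part of TensorFlow. The sketch is deliberately written without importing TensorFlow so that it stands on its own; in real code it would probably derive from tf.train.SessionRunHook. It assumes that run_context.original_args.feed_dict holds the feed_dict passed to the session.run call.

```python
class HandleGatedHook(object):
    """Forwards to `wrapped` only for run calls that feed `handle`.

    Hypothetical helper for illustration; not part of TensorFlow.
    """

    def __init__(self, wrapped, handle):
        self._wrapped = wrapped   # e.g. a SummarySaverHook
        self._handle = handle     # the tf.string placeholder for the handle
        self._active = False      # did the current run call feed the handle?

    def begin(self):
        self._wrapped.begin()

    def after_create_session(self, session, coord):
        self._wrapped.after_create_session(session, coord)

    def before_run(self, run_context):
        feed_dict = run_context.original_args.feed_dict or {}
        # Compare by identity so the check does not rely on tensor equality.
        self._active = any(key is self._handle for key in feed_dict)
        if self._active:
            return self._wrapped.before_run(run_context)
        return None  # request no extra fetches for this call

    def after_run(self, run_context, run_values):
        # Skip after_run too, since the wrapped hook fetched nothing.
        if self._active:
            self._wrapped.after_run(run_context, run_values)

    def end(self, session):
        self._wrapped.end(session)
```

In the failing example above, the idea would be to pass hooks=[HandleGatedHook(summary_hook, handle)] instead of hooks=[summary_hook], so that the initial call fetching the two string handles does not trigger the summary op.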
Alternatively, you can use the latest master of TensorFlow (post-1.4), which adds a run_step_fn method to MonitoredSession. It lets you specify the following step_fn, which avoids the error (at the expense of evaluating the if statement once per training iteration):
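handles = {}  # filled on the first step; a dict avoids rebinding a local name

def step_fn(step_context):
    if not handles:
        # step_context.session is the underlying session, so no hooks run here.
        handles['train'], handles['val'] = step_context.session.run(
            [iter_train_handle, iter_val_handle])
    return step_context.run_with_hooks(fetches=..., feed_dict=...)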