I want to use the feedable iterator design in the TensorFlow Dataset API so that I can switch to validation data after some training steps. But when I switch to the validation data, the whole session ends.
The following code demonstrates what I want to do:
import tensorflow as tf
graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()
    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()
with graph.as_default():
    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)
            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print(' {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)
The training data has 32 elements batched by 4, which gives 8 batches. Validation runs every 4 training steps, so I expect:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
# 1 [validation]
# 2 [validation]
But it stops as soon as the first validation pass is done:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
So, how can I use this feedable iterator in tf.train.MonitoredTrainingSession?
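The session ends because tf.train.MonitoredTrainingSession treats tf.errors.OutOfRangeError as the end of input: when the validation iterator is exhausted, the error escapes both loops, and the session's `with` block suppresses it and closes the session silently. I would suggest catching the tf.errors.OutOfRangeError raised at the end of the validation dataset inside the validation loop. (For another solution, based on the repeat() transformation, see the "Processing multiple epochs" section of the official Dataset API guide.)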
while not sess.should_stop():
    x = sess.run(next_element, feed_dict={handle: training_handle})
    count_training += 1
    print('{} [training] {}'.format(count_training, x.shape))
    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict={handle: validation_handle})
                count_validation += 1
                print(' {} [validation] {}'.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                # validation set exhausted: leave only this inner loop and
                # resume training instead of letting the error end the session
                break
This piece of code prints:
1 [training] (4,)
2 [training] (4,)
3 [training] (4,)
4 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)
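To see why the original loop ends the session while the try/except version keeps going, here is a small pure-Python model of the control flow. It is not TensorFlow code: `OutOfRangeError`, `FakeMonitoredSession`, `FakeIterator` and `run` are hypothetical stand-ins, and the fake session models only one assumed behaviour of tf.train.MonitoredTrainingSession, namely that leaving its `with` block with an OutOfRangeError closes the session and suppresses the error.

```python
# Pure-Python model, NOT TensorFlow code. All names are hypothetical stand-ins.


class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError."""


class FakeMonitoredSession:
    """Models one assumed behaviour of tf.train.MonitoredTrainingSession:
    leaving its `with` block with OutOfRangeError (or StopIteration)
    closes the session and suppresses the error."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.closed = True
        # Returning True tells Python to swallow the exception.
        return exc_type in (OutOfRangeError, StopIteration)

    def should_stop(self):
        return self.closed


class FakeIterator:
    """Yields `n_batches` batches per initialization, then raises
    OutOfRangeError, like an exhausted initializable iterator."""

    def __init__(self, n_batches):
        self.n_batches = n_batches
        self.remaining = 0

    def initialize(self):
        self.remaining = self.n_batches

    def get_next(self):
        if self.remaining == 0:
            raise OutOfRangeError("End of sequence")
        self.remaining -= 1


def run(catch_validation_end):
    """Replays the question's loop and returns the lines it would print."""
    log = []
    # 32 elements / batch 4 = 8 batches; 8 elements / batch 4 = 2 batches
    training, validation = FakeIterator(8), FakeIterator(2)
    with FakeMonitoredSession() as sess:
        training.initialize()
        count_training = 0
        while not sess.should_stop():
            training.get_next()  # raises once all 8 batches are used
            count_training += 1
            log.append("{} [training]".format(count_training))
            if count_training % 4 == 0:
                validation.initialize()
                count_validation = 0
                while True:
                    try:
                        validation.get_next()
                    except OutOfRangeError:
                        if catch_validation_end:
                            break  # the answer: end only this validation pass
                        raise  # the question: the error escapes to __exit__
                    count_validation += 1
                    log.append("{} [validation]".format(count_validation))
    return log


if __name__ == "__main__":
    print(run(catch_validation_end=False))  # stops after the first validation
    print(run(catch_validation_end=True))   # 8 training steps, 2 validation passes
```

In this model, run(catch_validation_end=False) reproduces the question's output (four training steps, two validation steps, then nothing), and run(catch_validation_end=True) reproduces the answer's output. The inner `while True` stands in for the question's inner `while not sess.should_stop()`, whose check never becomes true before the error is raised.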