I am building a data pipeline using the Dataset API, but when I train on multiple GPUs and return dataset.make_one_shot_iterator().get_next()
in my input function, I get
ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy
I can follow the error message and return the dataset directly, but I do not understand the purpose of make_one_shot_iterator().get_next()
and how it behaves when training on a single GPU versus multiple GPUs.
...
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset.make_one_shot_iterator().get_next()
return _input_fn
When using tf.data with a distribution strategy (which can be used with Keras and tf.Estimators), your input_fn should return a tf.data.Dataset:
def input_fn():
    dataset = ...  # construct your dataset here
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size=batch_size)
    dataset = dataset.cache()
    dataset = dataset.prefetch(buffer_size=None)
    return dataset
...use input_fn...
See documentation on distribution strategy.
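For instance, the wiring looks roughly like this (a minimal sketch assuming the TF 1.x Estimator API; model_fn and the dataset construction are placeholders, not code from the question):
import tensorflow as tf

def model_fn(features, labels, mode):
    ...  # hypothetical model definition

# One replica per available GPU; the strategy consumes the Dataset
# returned by input_fn and distributes the input across replicas.
strategy = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)

estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn=input_fn)
Note that the same input_fn works unchanged on a single GPU: simply omit train_distribute from the RunConfig.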
dataset.make_one_shot_iterator()
is useful outside of distribution strategies and higher-level libraries, for example if you are using lower-level libraries or debugging/testing a dataset. For example, you can iterate over all of a dataset's elements like so:
import tensorflow as tf

dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with tf.Session() as sess:
    while True:
        try:  # the iterator raises OutOfRangeError once exhausted
            print(sess.run(get_next))
        except tf.errors.OutOfRangeError:
            break
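As an aside, later TF 1.x releases deprecate dataset.make_one_shot_iterator() in favor of a functional form, and in TF 2.x eager execution you iterate the dataset directly in Python; both are sketched below (verify against the TF version you are running):
# TF 1.13+ replacement for the deprecated method call
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

# TF 2.x: a Dataset is a Python iterable; no iterator or Session needed
for element in dataset:
    print(element)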