 

Should I return the dataset directly, or should I use a one-shot iterator instead?

I am building a data pipeline using the Dataset API, but when I train on multiple GPUs and return dataset.make_one_shot_iterator().get_next() from my input function, I get:

ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy

I can follow the error message and return the dataset directly, but I do not understand the purpose of make_one_shot_iterator().get_next(), or how it works when training on a single GPU versus multiple GPUs.

...

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size = batch_size)
    dataset = dataset.cache()

    dataset = dataset.prefetch(buffer_size=None)

    return dataset.make_one_shot_iterator().get_next()

return _input_fn
asked Feb 04 '19 16:02 by Val



1 Answer

When using tf.data with a distribution strategy (which can be used with Keras and tf.estimator), your input_fn should return a tf.data.Dataset:

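import tensorflow as tf

def input_fn():
  dataset = ...  # build the source dataset here
  dataset = dataset.cache()  # cache before repeat so each element is stored only once
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)  # buffer_size=None is not a valid size
  return dataset  # the Dataset itself, not an iterator or a get_next() tensor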

...use input_fn...

See the TensorFlow documentation on distribution strategies (tf.distribute).
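
To illustrate why the strategy insists on a Dataset, here is a minimal sketch assuming TensorFlow 2.x (eager mode, not the TF 1.x graph API used in the question). The toy data, the batch sizes, the model and the input_fn parameters are hypothetical. The strategy receives the Dataset object and splits each global batch across replicas itself, which it cannot do with an already-built get_next() tensor. On a machine without GPUs, MirroredStrategy falls back to a single CPU replica, so the same code covers the single-device and multi-GPU cases.

```python
import numpy as np
import tensorflow as tf

# Mirrors variables on every visible GPU; falls back to one CPU replica.
strategy = tf.distribute.MirroredStrategy()

# 1) Custom loop: the strategy shards each global batch across replicas.
dataset = tf.data.Dataset.range(8).batch(4)  # global batches [0..3] and [4..7]
dist_dataset = strategy.experimental_distribute_dataset(dataset)

batch_sums = []
for per_replica_batch in dist_dataset:
    # Each replica sums its own slice of the global batch...
    partial = strategy.run(tf.reduce_sum, args=(per_replica_batch,))
    # ...and the partial sums are combined across replicas.
    total = strategy.reduce(tf.distribute.ReduceOp.SUM, partial, axis=None)
    batch_sums.append(int(total))

# 2) Keras: pass the Dataset returned by the input function straight to fit().
# Hypothetical toy regression data standing in for the real source.
features = np.random.rand(64, 4).astype("float32")
labels = np.random.rand(64, 1).astype("float32")

def input_fn(batch_size=8, num_epochs=1):
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.cache()
    ds = ds.repeat(num_epochs)
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)

with strategy.scope():
    # Variables created here are mirrored on each replica.
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

history = model.fit(input_fn(), epochs=1, verbose=0)
```

With several GPUs each global batch of 4 is split between the replicas, but the reduced sums are the same as on a single CPU (6 and 22).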

dataset.make_one_shot_iterator() is useful outside of distribution strategies and higher-level libraries, such as when you are using lower-level TensorFlow APIs directly, or when debugging or testing a dataset. For example, you can iterate over all of a dataset's elements like so:

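import tensorflow as tf  # TF 1.x API (graph mode with Sessions)

dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()  # a tensor that yields the next element each time it is run
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(get_next))
    except tf.errors.OutOfRangeError:  # raised once every element has been consumed
      break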
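
For comparison, a minimal sketch of the same kind of debugging loop assuming TensorFlow 2.x, where eager execution replaces the Session (tf.compat.v1.data.make_one_shot_iterator remains available for TF 1.x-style graph code). The toy range dataset is only for illustration. A Dataset can be iterated as many times as you like, while a single iterator is used up after one pass, which is the "one-shot" behaviour.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5).map(lambda x: x * 2)  # toy dataset: 0, 2, 4, 6, 8

# Eager mode: a Dataset is a Python iterable, so no Session is needed.
values = [int(x) for x in dataset]

# Pulling elements one at a time, the eager analogue of get_next().
it = iter(dataset)
first = int(next(it))
second = int(next(it))
rest = [int(x) for x in it]  # consumes the remaining elements
exhausted = next(it, None)   # the iterator cannot be rewound, so this is None

# Iterating the Dataset again starts a fresh pass from the beginning.
again = [int(x) for x in dataset]

# For quick inspection, as_numpy_iterator() yields NumPy values.
numpy_values = list(dataset.as_numpy_iterator())
```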
answered Oct 10 '22 23:10 by rachelim
