
How to use tf.data's initializable iterators within a tf.estimator's input_fn?

I would like to manage my training with a tf.estimator.Estimator, but I am having trouble using it alongside the tf.data API.

I have something like this:

def model_fn(features, labels, params, mode):
  # Defines model's ops.
  # Initializes with tf.train.Scaffold.
  # Returns a tf.estimator.EstimatorSpec.

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.

  iterator = dataset.make_initializable_iterator()

  return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)

I can't use make_one_shot_iterator for my use case, so input_fn creates an initializable iterator. My issue is that this iterator has to be initialized from within model_fn (here, I use tf.train.Scaffold to initialize local ops).

I also understand that I can't simply pass input_fn = iterator.get_next, because the dataset ops would then be created outside the Estimator's graph instead of in the same graph as the model's ops.

What is the recommended way to initialize the iterator?

guillaumekln asked Jul 10 '17 12:07



1 Answer

As of TensorFlow 1.5, it is possible to make input_fn return a tf.data.Dataset, e.g.:

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.
  return dataset

See c294fcfd.


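For earlier versions, you can add the iterator's initializer to the tf.GraphKeys.TABLE_INITIALIZERS collection and rely on the default initialization: the default tf.train.Scaffold runs tf.tables_initializer() as part of its local init op, so the iterator is initialized together with any lookup tables.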

tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
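
To illustrate the mechanism outside of an Estimator, here is a minimal sketch. It makes several assumptions that are not in the answer above: it uses the tf.compat.v1 aliases so that it also runs on later TensorFlow releases (in TensorFlow 1.x the equivalent calls are dataset.make_initializable_iterator() and tf.add_to_collection), it replaces the "test.txt" file with a small in-memory dataset, and it runs tf.tables_initializer() by hand in a plain session as a stand-in for what the default Scaffold would do. The run_demo helper is hypothetical and exists only for this illustration.

```python
import tensorflow as tf

# Assumption: tf.compat.v1 aliases are used so the sketch also runs on
# TensorFlow releases newer than the one discussed in the answer.
tf1 = tf.compat.v1


def input_fn():
    # In-memory stand-in for tf.data.TextLineDataset("test.txt").
    dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    # map, shuffle, padded_batch, etc. would go here.

    iterator = tf1.data.make_initializable_iterator(dataset)

    # Register the iterator's initializer with the table initializers,
    # so that running tables_initializer() also initializes the iterator.
    tf1.add_to_collection(tf1.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    return iterator.get_next()


def run_demo():
    # Hypothetical helper: build the input pipeline in a fresh graph and
    # read three elements after running only the table initializers.
    with tf.Graph().as_default():
        next_element = input_fn()
        with tf1.Session() as sess:
            # Stand-in for the default Scaffold's local init op.
            sess.run(tf1.tables_initializer())
            return [sess.run(next_element) for _ in range(3)]
```

Calling run_demo() reads the three elements in order without ever running iterator.initializer explicitly, which is the behaviour the workaround relies on. Note that if you pass your own local_init_op to tf.train.Scaffold, it would presumably need to include tf.tables_initializer() for this to keep working.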
guillaumekln answered Nov 02 '22 11:11
