I would like to manage my training with a tf.estimator.Estimator, but I am having trouble using it alongside the tf.data API.
I have something like this:
```python
import tensorflow as tf

def model_fn(features, labels, params, mode):
    # Defines the model's ops.
    # Initializes with tf.train.Scaffold.
    # Returns a tf.estimator.EstimatorSpec.
    ...

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc.
    iterator = dataset.make_initializable_iterator()
    return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
```
Since I can't use make_one_shot_iterator for my use case, my problem is that input_fn contains an iterator that has to be initialized within model_fn (here, I use tf.train.Scaffold to initialize local ops).
Also, as I understand it, I can't simply pass input_fn = iterator.get_next, because then the other ops would not be added to the same graph.
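To make the problem concrete, here is a minimal sketch of why the initializer matters. It assumes the tf.compat.v1 API (available in TF 1.13+ and 2.x) and uses a small in-memory dataset as a stand-in for test.txt: fetching from an initializable iterator fails with FailedPreconditionError until iterator.initializer has been run in the same session.

```python
# Assumes TensorFlow with the tf.compat.v1 API (TF 1.13+ or 2.x).
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # In-memory stand-in for tf.data.TextLineDataset("test.txt").
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    # Fetching before the initializer has run raises FailedPreconditionError.
    try:
        sess.run(next_element)
        failed_before_init = False
    except tf.errors.FailedPreconditionError:
        failed_before_init = True

    # After running iterator.initializer, the same fetch succeeds.
    sess.run(iterator.initializer)
    first_value = sess.run(next_element)

print(failed_before_init, first_value)
```

Inside an Estimator the session is created for you, which is why something has to run that initializer on your behalf.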
What is the recommended way to initialize the iterator?
As of TensorFlow 1.5, it is possible to make input_fn return a tf.data.Dataset, e.g.:
```python
import tensorflow as tf

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc.
    return dataset
```
See commit c294fcfd.
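My understanding is that the Estimator then does something along these lines internally: it builds an initializable iterator from the returned dataset and runs its initializer from a session hook once the session exists. The sketch below reproduces that idea outside an Estimator, assuming the tf.compat.v1 API; IteratorInitializerHook is a hypothetical name, not the Estimator's actual internal class, and an in-memory dataset replaces test.txt.

```python
# Assumes TensorFlow with the tf.compat.v1 API (TF 1.13+ or 2.x).
import tensorflow as tf

class IteratorInitializerHook(tf.compat.v1.train.SessionRunHook):
    """Hypothetical hook that runs an iterator's initializer after session creation."""

    def __init__(self, iterator):
        self._iterator = iterator

    def after_create_session(self, session, coord):
        # Runs once the session exists and the graph is finalized,
        # so it only executes an existing op and adds nothing to the graph.
        session.run(self._iterator.initializer)

def input_fn():
    # In-memory stand-in for tf.data.TextLineDataset("test.txt").
    dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
    # map, shuffle, padded_batch, etc.
    return dataset

graph = tf.Graph()
with graph.as_default():
    iterator = tf.compat.v1.data.make_initializable_iterator(input_fn())
    next_element = iterator.get_next()
    hook = IteratorInitializerHook(iterator)
    with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:
        values = [sess.run(next_element) for _ in range(3)]

print(values)
```

Because the hook owns the initialization, neither input_fn nor model_fn needs to know about the iterator.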
For earlier versions, you can add the iterator's initializer to the tf.GraphKeys.TABLE_INITIALIZERS collection and rely on the default initializer:
```python
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
```
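This works because the default tf.train.Scaffold builds its local_init_op from, among other things, tf.tables_initializer(), which groups every op in TABLE_INITIALIZERS. A minimal sketch, assuming the tf.compat.v1 API and an in-memory dataset, with MonitoredSession standing in for the session an Estimator would create:

```python
# Assumes TensorFlow with the tf.compat.v1 API (TF 1.13+ or 2.x).
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # In-memory stand-in for the dataset built in input_fn.
    dataset = tf.data.Dataset.from_tensor_slices([10, 20, 30])
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    # Register the iterator's initializer alongside table initializers.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    # No explicit initializer call: the default Scaffold's local_init_op
    # includes tf.compat.v1.tables_initializer(), which runs it.
    with tf.compat.v1.train.MonitoredSession() as sess:
        values = [sess.run(next_element) for _ in range(3)]

print(values)
```

Note that if your model_fn passes its own local_init_op to tf.train.Scaffold, that op replaces the default one, so you would need to include tf.tables_initializer() in it yourself.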