In the TensorFlow guides, the input function for the Iris data example is described in two separate places. One version returns just the dataset itself, while the other returns the output of an iterator over the dataset.
From the premade Estimator guide: https://www.tensorflow.org/guide/premade_estimators
def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)
From the custom estimator guide: https://www.tensorflow.org/guide/custom_estimators
def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    # Return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()
I'm confused about which one is correct. If both are valid for different cases, when is it correct to return the iterator's output instead of the dataset?
Using pre-made Estimators is recommended when just getting started. To write a TensorFlow program based on pre-made Estimators, you must perform the following tasks: create one or more input functions, and define the model's feature columns.
To iterate over the dataset several times, use .repeat(). Batches can be enumerated either with Python's built-in enumerate() or with the dataset's own enumerate() method. The latter yields the index as a tensor, which is recommended.
The “train_op” and the scalar loss tensor are the minimum required arguments to create an “EstimatorSpec” for training.
If your input function returns a tf.data.Dataset, an iterator is created under the hood and its get_next() function is used to supply inputs to the model. This is somewhat hidden in the source code; see parse_input_fn_result.
I believe this was only implemented in a more recent update, so older tutorials still explicitly return get_next() in their input function, since it was the only option back then. There should be no difference between using either, but you can save a tiny bit of code by returning the dataset instead of the iterator.
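To make the "under the hood" part concrete, here is a rough sketch of the dispatch idea in plain Python. It does not use TensorFlow and is not the library's code: ToyDataset, ToyIterator and parse_input_fn_result_sketch are hypothetical stand-ins invented for illustration, and the real parse_input_fn_result does more than this. The only point is that both styles of input function end up handing the same (features, labels) pair to the model.

```python
class ToyIterator:
    """Hypothetical stand-in for a dataset iterator (not a TensorFlow class)."""

    def __init__(self, batches):
        self._it = iter(batches)

    def get_next(self):
        # Return the next (features, labels) batch.
        return next(self._it)


class ToyDataset:
    """Hypothetical stand-in for tf.data.Dataset; holds pre-built batches."""

    def __init__(self, batches):
        self._batches = list(batches)

    def make_one_shot_iterator(self):
        return ToyIterator(self._batches)


def parse_input_fn_result_sketch(result):
    """Simplified illustration: accept either a dataset or a (features, labels) pair."""
    if isinstance(result, ToyDataset):
        # A dataset was returned, so build the iterator on the caller's behalf.
        result = result.make_one_shot_iterator().get_next()
    features, labels = result
    return features, labels


def input_fn_returning_dataset():
    # Style 1: return the dataset itself.
    return ToyDataset([({"SepalLength": [5.1, 4.9]}, [0, 0])])


def input_fn_returning_tensors():
    # Style 2: return the read end of the pipeline explicitly.
    return input_fn_returning_dataset().make_one_shot_iterator().get_next()


from_dataset = parse_input_fn_result_sketch(input_fn_returning_dataset())
from_tensors = parse_input_fn_result_sketch(input_fn_returning_tensors())
print(from_dataset == from_tensors)
```

Under these assumptions the two results compare equal, which mirrors why either form of train_input_fn works with an Estimator.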