
When to use an iterator in a TensorFlow Estimator input function

The TensorFlow guides describe the input function for the Iris data example in two separate places. One input function returns just the dataset itself, while the other creates an iterator over the dataset and returns the result of its get_next() call.

From the premade Estimator guide: https://www.tensorflow.org/guide/premade_estimators

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)

From the custom estimator guide: https://www.tensorflow.org/guide/custom_estimators

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

I'm confused about which one is correct. If both are valid for different cases, when should the input function return the iterator's get_next() result rather than the dataset itself?

D Myers asked Oct 31 '18 03:10



1 Answer

If your input function returns a tf.data.Dataset, the Estimator creates an iterator under the hood and uses its get_next() function to supply inputs to the model. This is somewhat hidden in the source code; see parse_input_fn_result in the Estimator source.

I believe this was only added in a more recent release, so older tutorials still return get_next() explicitly in their input functions, since that was the only option at the time. The two forms should behave identically, but returning the dataset saves a line of code.
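To make the "under the hood" step concrete, here is a minimal, TensorFlow-free sketch of the dispatch idea. ToyDataset, ToyIterator, and resolve_input_fn_result are hypothetical stand-ins invented for this illustration; this is not the actual code of parse_input_fn_result, only the general shape of "if a dataset comes back, build the iterator and call get_next() for the caller".

```python
class ToyIterator:
    """Hypothetical stand-in for a one-shot iterator (not a TensorFlow class)."""

    def __init__(self, elements):
        self._it = iter(elements)

    def get_next(self):
        # Hand back the next (features, labels) pair.
        return next(self._it)


class ToyDataset:
    """Hypothetical stand-in for tf.data.Dataset (not the real class)."""

    def __init__(self, elements):
        self._elements = list(elements)

    def make_one_shot_iterator(self):
        return ToyIterator(self._elements)


def resolve_input_fn_result(result):
    """Return (features, labels) whichever style the input function used."""
    if isinstance(result, ToyDataset):
        # A dataset was returned: create the iterator on the caller's behalf.
        result = result.make_one_shot_iterator().get_next()
    features, labels = result
    return features, labels


def input_fn_returning_dataset():
    # Newer style: return the dataset itself.
    return ToyDataset([({"x": 1.0}, 0), ({"x": 2.0}, 1)])


def input_fn_returning_get_next():
    # Older style: return the read end of the pipeline explicitly.
    return input_fn_returning_dataset().make_one_shot_iterator().get_next()


from_dataset = resolve_input_fn_result(input_fn_returning_dataset())
from_get_next = resolve_input_fn_result(input_fn_returning_get_next())
```

In this toy model both input functions end up handing the same first (features, labels) pair to the consumer, which is the sense in which the two styles are interchangeable.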

xdurch0 answered Nov 07 '22 05:11