All the official Google tutorials use the one-shot iterator in their Estimator API examples; I couldn't find any documentation on how to use tf.data's initializable iterator or reinitializable iterator instead of the one-shot iterator.
Can someone kindly show me how to switch between train_data and test_data using tf.data's initializable or reinitializable iterator? To switch the dataset in an initializable iterator we need to run a session and pass a feed_dict, which is a low-level API, and it's confusing how to use that as part of the Estimator API architecture.
PS: I did find that Google mentions "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."
But is there any workaround within the community, or should we just stick with the one-shot iterator for some good reason?
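For reference, this is the low-level pattern I mean outside of an Estimator (TF 1.x; the in-memory arrays here are just illustrative stand-ins), a minimal sketch of switching data by re-running the initializer with a feed_dict:

import numpy as np
import tensorflow as tf

# Illustrative stand-ins for the real train/test arrays.
X_train = np.random.rand(100, 4).astype(np.float32)
y_train = np.random.randint(2, size=100)
X_test = np.random.rand(20, 4).astype(np.float32)
y_test = np.random.randint(2, size=20)

X_pl = tf.placeholder(tf.float32, [None, 4])
y_pl = tf.placeholder(tf.int64, [None])

dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl)).batch(32)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    # Point the iterator at the training data...
    sess.run(iterator.initializer, feed_dict={X_pl: X_train, y_pl: y_train})
    print(sess.run(next_batch))
    # ...then re-run the initializer to switch to the test data.
    sess.run(iterator.initializer, feed_dict={X_pl: X_test, y_pl: y_test})
    print(sess.run(next_batch))

The problem is that the Estimator owns the tf.Session, so there is no obvious place to make these sess.run(iterator.initializer, ...) calls.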
To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. This class then has access to the session used by the tf.estimator functions.
Here is a quick example that you can adapt to your needs:
import tensorflow as tf  # TF 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook that runs the dataset iterator's initializer once the session exists."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Called by the Estimator right after it creates the tf.Session,
        # which is our chance to run the iterator's initializer ourselves.
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep the (possibly large) numpy arrays out of the graph definition.
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...
        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()

        # Feed the real data once the hook fires.
        iterator_initializer_hook.iterator_initializer_func = \
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
This is a modified version of code I found in a blog post by Sebastian Pölsterl. Have a look at the "Feeding data to an Estimator via the Dataset API" section.
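For completeness: the reinitializable iterator the question mentions is normally built with tf.data.Iterator.from_structure and driven from your own session. Since each estimator.train()/evaluate() call builds its own graph and session, the hook-based initializable iterator above is the usual workaround inside an Estimator; the sketch below (TF 1.x, with illustrative in-memory arrays) just shows the plain-session version for comparison:

import numpy as np
import tensorflow as tf

# Illustrative stand-ins for the real numpy arrays.
X_train = np.random.rand(100, 4).astype(np.float32)
y_train = np.random.randint(2, size=100)
X_test = np.random.rand(20, 4).astype(np.float32)
y_test = np.random.randint(2, size=20)

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)

# One iterator defined only by the datasets' structure, with one initializer per dataset.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
next_batch = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

with tf.Session() as sess:
    sess.run(train_init_op)   # iterate over the training data
    print(sess.run(next_batch))
    sess.run(test_init_op)    # switch the same iterator to the test data
    print(sess.run(next_batch))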