How to use tf.data's initializable and reinitializable iterators and feed data to the Estimator API?

All the official Google tutorials use the one-shot iterator for their Estimator API implementations. I couldn't find any documentation on how to use tf.data's initializable and reinitializable iterators instead of the one-shot iterator.

Can someone kindly show me how to switch between train_data and test_data using tf.data's initializable and reinitializable iterators? We need to run a session to use feed_dict and switch the dataset in the initializable iterator. That is a low-level API, and it is confusing how to use it as part of the Estimator API architecture.

PS: I did find that Google mentions: "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."

But is there any workaround within the community? Or should we just stick with the one-shot iterator for some good reason?

AVR asked Aug 01 '18 04:08


1 Answer

To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. This class then has access to the session used by the tf.estimator functions, so it can run the iterator's initializer in that session.

Here is a quick example that you can adapt to your needs:

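import tensorflow as tf  # TensorFlow 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs the iterator initializer once the session has been created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Called after the session is created; initialize the iterator here
        self.iterator_initializer_func(session)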


def get_inputs(X, y):
    # One hook per dataset; it is handed to the Estimator together with input_fn
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep the arrays out of the graph definition
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()

        # Tell the hook how to initialize this iterator with the actual data
        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer, feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

This is a modified version of code I found in a blog post by Sebastian Pölsterl. Have a look at the "Feeding data to an Estimator via the Dataset API" section.
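To make the order of calls easier to follow, here is a framework-free sketch of the same pattern. It does not use TensorFlow at all: FakeSession and fake_train are hypothetical stand-ins I made up for illustration, and the strings stand in for graph ops and placeholders. It assumes the Estimator calls input_fn while building the graph, then creates the session, then calls after_create_session on each hook.

```python
# Framework-free sketch of the call order. FakeSession and fake_train are
# hypothetical stand-ins for illustration, not TensorFlow APIs.


class FakeSession:
    """Stand-in for a session: records what it is asked to run."""

    def __init__(self):
        self.calls = []

    def run(self, op, feed_dict=None):
        self.calls.append((op, feed_dict))


class IteratorInitializerHook:
    """Same shape as the hook above, without the SessionRunHook base class."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    hook = IteratorInitializerHook()

    def input_fn():
        # The real code builds placeholders and the iterator here;
        # strings stand in for those graph objects.
        hook.iterator_initializer_func = lambda sess: sess.run(
            "iterator.initializer", feed_dict={"X_pl": X, "y_pl": y})
        return "next_example", "next_label"

    return input_fn, hook


def fake_train(input_fn, hooks):
    """Imitates the assumed order: build graph, create session, call hooks."""
    input_fn()                      # 1. graph construction sets the callback
    session = FakeSession()         # 2. the session is created afterwards
    for hook in hooks:              # 3. each hook initializes its iterator
        hook.after_create_session(session, coord=None)
    return session


train_input_fn, train_hook = get_inputs([1, 2, 3], [0, 1, 0])
callback_before = train_hook.iterator_initializer_func  # still None here
session = fake_train(train_input_fn, [train_hook])
print(session.calls)
```

The point of the sketch is that the hook is created before the callback exists. The callback is only filled in when input_fn runs, which is why the hook and the input_fn have to be created as a pair and passed to the same train or evaluate call.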

Olivier Dehaene answered Sep 18 '22 08:09