
Avoiding tf.data.Dataset.from_tensor_slices with estimator api

I am trying to figure out the recommended way to use the Dataset API together with the Estimator API. Everything I have seen online is some variation of this:

def train_input_fn():
   dataset = tf.data.Dataset.from_tensor_slices((features, labels))
   return dataset

which can then be passed to the estimator's train function:

 classifier.train(
    input_fn=train_input_fn,
    #...
 )

but the dataset guide warns that:

the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.

and then describes a method that involves defining placeholders which are then filled with the feed_dict:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

But if you're using the estimator api, you're not manually running the session. So how do you use the dataset api with estimators while avoiding the problems associated with from_tensor_slices()?

Mark asked Sep 10 '18 21:09


1 Answer

To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. Its callbacks are handed the session at several points during training and evaluation, which is the access the Estimator otherwise hides from you.

You can then use this class to initialize the iterator as you would normally do in a classic setting. You simply need to pass the hook to the training/evaluation functions, or to the corresponding tf.estimator.TrainSpec or tf.estimator.EvalSpec.

Here is a quick example that you can adapt to your needs:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        # Defer initialization until the hook is given the session.
        iterator_initializer_hook.iterator_initializer_func = (
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y}))

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
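
Whether the hook is worth the extra machinery depends on how large the arrays are relative to the 2GB tf.GraphDef limit quoted in the question. The following is only a rough sketch for estimating that: the helper names (array_nbytes, exceeds_graphdef_limit) and the example shapes are hypothetical, treating 2GB as 2 * 1024**3 bytes is an assumption, and the estimate ignores protobuf overhead and the extra copies the guide mentions.

```python
import math

# Assumption: the "2GB" limit is taken as 2 * 1024**3 bytes.
GRAPHDEF_LIMIT_BYTES = 2 * 1024 ** 3


def array_nbytes(shape, itemsize):
    """Raw size in bytes of a dense array with this shape and element size."""
    return math.prod(shape) * itemsize


def exceeds_graphdef_limit(shapes_and_itemsizes):
    """True if the arrays alone would already reach the assumed limit."""
    total = sum(array_nbytes(shape, itemsize)
                for shape, itemsize in shapes_and_itemsizes)
    return total >= GRAPHDEF_LIMIT_BYTES


# Hypothetical datasets: (shape, bytes per element) for features and labels.
small = [((60_000, 784), 4), ((60_000,), 8)]        # float32 features, int64 labels
large = [((1_000_000, 600), 4), ((1_000_000,), 8)]

print(exceeds_graphdef_limit(small))
print(exceeds_graphdef_limit(large))
```

If the estimate is well under the limit, the plain from_tensor_slices((features, labels)) input_fn from the question may be acceptable; once it approaches the limit, feeding the arrays through placeholders as above avoids embedding them in the graph.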
Olivier Dehaene answered Sep 28 '22 11:09