
Replacing tf.placeholder and feed_dict with tf.data API


I have an existing TensorFlow model which used a tf.placeholder for the model input and the feed_dict parameter of tf.Session().run to feed in data. Previously the entire dataset was read into memory and passed in this way.

I want to use a much larger dataset and take advantage of the performance improvements of the tf.data API. I've defined a tf.data.TextLineDataset and one-shot iterator from it, but I'm having a hard time figuring out how to get the data into the model to train it.

At first I tried to just define the feed_dict as a dictionary from the placeholder to iterator.get_next(), but that gave me an error saying the value of a feed cannot be a tf.Tensor object. More digging led me to understand that this is because the object returned by iterator.get_next() is already part of the graph, unlike what you would feed into feed_dict -- and that I shouldn't be trying to use feed_dict at all anyway for performance reasons.

So now I've gotten rid of the input tf.placeholder and replaced it with a parameter to the constructor of the class that defines my model; when constructing the model in my training code, I pass the output of iterator.get_next() to that parameter. This already seems a bit clunky because it breaks separation between the definition of the model and the datasets/training procedure. And I'm now getting an error saying that the Tensor representing (I believe) my model's input must be from the same graph as the Tensor from iterator.get_next().

Am I on the right track with this approach and just doing something wrong with how I set up the graph and the session, or something like that? (The datasets and model are both initialized outside of a session, and the error occurs before I attempt to create one.)

Or am I totally off base with this and need to do something different like use the Estimator API and define everything in an input function?

Here is some code demonstrating a minimal example:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
Asked Apr 10 '18 by erobertc


People also ask

What is TF Data API?

The tf.data API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training.
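As a rough illustration (a minimal TF 1.x sketch using placeholder in-memory data rather than real files), a pipeline composed from these reusable pieces might look like this:

import numpy as np
import tensorflow as tf

# Placeholder data standing in for records read from disk.
features = np.random.standard_normal([100, 10]).astype(np.float32)

# Compose the pipeline from small, reusable transformations:
# a random per-element perturbation, shuffling, and batching.
dataset = (tf.data.Dataset.from_tensor_slices(features)
           .map(lambda x: x + tf.random_normal(tf.shape(x), stddev=0.1))
           .shuffle(buffer_size=100)
           .batch(16))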

What is TF data prefetch?

The tf.data API provides the tf.data.Dataset.prefetch transformation. It can be used to decouple the time when data is produced from the time when data is consumed. In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested.
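For example (a hedged TF 1.x sketch with placeholder data), appending prefetch to the end of a pipeline keeps the next batch buffered while the current one is being consumed:

import numpy as np
import tensorflow as tf

features = np.random.standard_normal([100, 10]).astype(np.float32)

# prefetch(1) keeps one batch ready ahead of each training step.
dataset = tf.data.Dataset.from_tensor_slices(features).batch(16).prefetch(1)
next_batch = dataset.make_one_shot_iterator().get_next()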

How can I pre-process input data using tf data?

When preparing data, input elements may need to be pre-processed. To this end, the tf.data API offers the tf.data.Dataset.map transformation, which applies a user-defined function to each element of the input dataset. Because input elements are independent of one another, the pre-processing can be parallelized across multiple CPU cores.
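As a sketch (the file name and parse function below are hypothetical, not from the original post), Dataset.map accepts a num_parallel_calls argument in TF 1.4+ that spreads the per-element work across CPU cores:

import tensorflow as tf

def parse_line(line):
    # Split a CSV line and convert the fields to floats.
    fields = tf.string_split([line], ",").values
    return tf.string_to_number(fields, out_type=tf.float32)

lines = tf.data.TextLineDataset("data.csv")  # hypothetical input file
parsed = lines.map(parse_line, num_parallel_calls=4)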

How to process multiple epochs of the same data in TF?

The tf.data API offers two main ways to process multiple epochs of the same data. The simplest is the Dataset.repeat() transformation: given a count, it repeats the input that many times, and applied with no arguments it repeats the input indefinitely.
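A minimal sketch of both uses of repeat (placeholder data, not the titanic example the guide uses):

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))

three_epochs = dataset.repeat(3)        # exactly three passes over the data
endless = dataset.repeat()              # repeats indefinitely
batched = dataset.repeat(3).batch(4)    # note: batches may span epoch boundaries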


1 Answer

It took a bit for me to get my head around too. You're on the right track. The entire Dataset definition is just part of the graph. I generally create it as a different class from my Model class and pass the dataset into the Model class. I specify the Dataset class I want to load on the command line and then load that class dynamically, thereby decoupling the Dataset and the graph modularly.
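A rough sketch of that structure (the class names, file path, and num_features parameter are illustrative, not from the original post): the dataset lives in its own class, and the model only ever sees the tensor handed to it.

import tensorflow as tf

class CsvData:
    def __init__(self, path, num_features, batch_size):
        def parse(line):
            values = tf.string_to_number(tf.string_split([line], ",").values, tf.float32)
            return tf.reshape(values, [num_features])  # give each element a static shape
        dataset = tf.data.TextLineDataset(path).map(parse).batch(batch_size)
        self.next_element = dataset.make_one_shot_iterator().get_next()

class Model:
    def __init__(self, x_in, output_size=3):
        self.layer = tf.layers.dense(x_in, output_size, activation=tf.nn.relu)

data = CsvData("train.csv", num_features=10, batch_size=2)  # hypothetical file
model = Model(x_in=data.next_element)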

Notice that you can (and should) name all the tensors in the Dataset; it really helps make things easier to understand as you pass data through the various transformations you'll need.

You can write simple test cases that pull samples from iterator.get_next() and display them; you'll have something like sess.run(next_element_tensor), with no feed_dict, as you've correctly noted.
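Something along these lines (a hedged sketch with placeholder data):

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    np.random.standard_normal([4, 10]).astype(np.float32)).batch(2)
next_element_tensor = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(next_element_tensor))  # first batch of 2 rows
    print(sess.run(next_element_tensor))  # second batch, no feed_dict anywhere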

Once you get your head around it you'll probably start liking the Dataset input pipeline. It forces you to modularize your code well, and it forces it into a structure that's easy to unit test.

Make sure you read the developer's guide; there are tons of examples there:

https://www.tensorflow.org/programmers_guide/datasets

Another thing I'll note is how easy it is to work with a train and test dataset with this pipeline. That's important because you often perform data augmentation on the training dataset that you don't perform on the test dataset; from_string_handle allows you to do that and is clearly described in the guide above.
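Here's a minimal sketch of that pattern (assuming TF 1.x, with placeholder datasets); note that the only thing fed is the small handle string, not the data itself:

import numpy as np
import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(
    np.random.standard_normal([8, 10]).astype(np.float32)).batch(2).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(
    np.random.standard_normal([4, 10]).astype(np.float32)).batch(2)

# One "feedable" iterator whose underlying dataset is chosen at run time.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_batch = iterator.get_next()

train_iter = train_ds.make_one_shot_iterator()
test_iter = test_ds.make_initializable_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())
    sess.run(test_iter.initializer)
    train_batch = sess.run(next_batch, feed_dict={handle: train_handle})
    test_batch = sess.run(next_batch, feed_dict={handle: test_handle})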

Answered Sep 19 '22 by David Parks