
Proper way to iterate tf.data.Dataset in session for 2.0

I have downloaded some *.tfrecord data from the youtube-8m project. You can download a 'small' portion of the data with this command:

curl data.yt8m.org/download.py | shard=1,100 partition=2/video/train mirror=us python

I am trying to get an idea of how to use the new tf.data API. I would like to become familiar with the typical ways people iterate through datasets. I have been using the guide on TF website and this slide: Derek Murray's Slides

Here is how I define the dataset:

# Use interleave() and prefetch() to read many files concurrently.
files = tf.data.Dataset.list_files("./youtube_vids/*.tfrecord")
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100),
                           cycle_length=8)

# Use num_parallel_calls to parallelize map().
dataset = dataset.map(lambda record: tf.parse_single_example(record, feature_map),
                      num_parallel_calls=2)

# put in x,y output form
dataset = dataset.map(lambda x: (x['mean_rgb'], x['id']))

# shuffle
dataset = dataset.shuffle(10000)

#one epoch
dataset = dataset.repeat(1)
dataset = dataset.batch(200)

#Use prefetch() to overlap the producer and consumer.
dataset = dataset.prefetch(10)
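(For context, feature_map is a dict of feature specs. The names and shapes below are illustrative only, not the real yt8m schema; here is a minimal sketch exercised on a synthetic serialized example:)

```python
import tensorflow as tf

# Illustrative feature_map; real yt8m video-level records use e.g. 1024-dim mean_rgb.
feature_map = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'mean_rgb': tf.io.FixedLenFeature([4], tf.float32),  # shortened for illustration
}

# Build a synthetic tf.train.Example matching the spec above.
example = tf.train.Example(features=tf.train.Features(feature={
    'id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'vid0'])),
    'mean_rgb': tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.1, 0.2, 0.3, 0.4])),
}))

parsed = tf.io.parse_single_example(example.SerializeToString(), feature_map)
print(parsed['id'].numpy())      # b'vid0'
print(parsed['mean_rgb'].shape)  # (4,)
```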

Now, I know in eager execution mode I can just

for x,y in dataset:
    x,y

However, when I attempt to create an iterator as follows:

# A one-shot iterator automatically initializes itself on first use.
iterator = dataset.make_one_shot_iterator()

# The return value of get_next() matches the dataset element type.
images, labels = iterator.get_next()

And run it with a session:

with tf.Session() as sess:

    # Loop until all elements have been consumed.
    try:
        while True:
            r = sess.run(images)
    except tf.errors.OutOfRangeError:
        pass

I get the warning

Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.

So, here is my question:

What is the proper way to iterate through a dataset within a session? Is it just a matter of v1 and v2 differences?

Also, the advice to pass the dataset directly to an estimator implies that the input function also has an iterator defined as in Derek Murray's slides above, correct?

asked May 31 '19 by leonard


1 Answer

As for the Estimator API, no, you don't have to define an iterator yourself; just return the dataset object from your input function.

def input_fn(filename):
    dataset = tf.data.TFRecordDataset(filename)
    # shuffle() and batch() require sizes; these values are illustrative.
    dataset = dataset.shuffle(buffer_size=10000).repeat()
    dataset = dataset.map(parse_func)
    dataset = dataset.batch(batch_size=32)
    return dataset

estimator.train(input_fn=lambda: input_fn(filename))

In TF 2.0 datasets became iterable, so, just as the warning message says, you can use

for x,y in dataset:
    x,y
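If you really need to stay inside a v1-style Session, the "last resort" path named in the warning also works. A minimal sketch, assuming TF 2.x with the compat.v1 API and a small synthetic dataset:

```python
import tensorflow as tf

# Session-based iteration requires graph mode in TF 2.x.
tf.compat.v1.disable_eager_execution()

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(4)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_batch = iterator.get_next()

results = []
with tf.compat.v1.Session() as sess:
    # Loop until the dataset is exhausted, as in the question.
    try:
        while True:
            results.append(sess.run(next_batch))
    except tf.errors.OutOfRangeError:
        pass

print([len(r) for r in results])  # [4, 4, 2]
```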
answered Nov 10 '22 by Sharky