
Interleaving tf.data.Datasets

Tags: tensorflow

I'm trying to use tf.data.Dataset to interleave two datasets, but I'm having trouble getting it to work. Given this simple example:

ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()

What should replace the ... so that the output is 0, 1, 2, 3, ..., 9?

Dataset.interleave() seems relevant, but I haven't been able to write the call in a way that doesn't raise an error.

RobR asked Nov 17 '17 04:11



1 Answer

MattScarpino is on the right track in his comment. You can use Dataset.zip() along with Dataset.flat_map() to flatten a multi-element dataset:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))

iter = dataset.make_one_shot_iterator()
val = iter.get_next()
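
To see why zipping and then flattening gives the interleaved order, here is a minimal plain-Python sketch of the same idea. It is only an analogy and runs no TensorFlow code: the built-in zip and itertools.chain stand in for Dataset.zip() and Dataset.flat_map(), and the ranges stand in for ds0 and ds1.

```python
from itertools import chain

# Stand-ins for the two datasets: the even and odd numbers below 10.
ds0 = range(0, 10, 2)
ds1 = range(1, 10, 2)

# zip pairs one element from each input: (0, 1), (2, 3), ..., (8, 9).
pairs = zip(ds0, ds1)

# Flattening each pair back into single elements restores one stream,
# which is the role flat_map plays in the answer above.
flattened = list(chain.from_iterable(pairs))

print(flattened)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Like Dataset.zip(), Python's zip stops at the shorter input, so this approach only covers every element when both inputs have the same length.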

Having said this, your intuition about using Dataset.interleave() is pretty sensible. We're investigating ways that you can do this more easily.


PS. As an alternative, you can use Dataset.interleave() to solve the problem if you change how ds0 and ds1 are defined:

dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)
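
The roles of cycle_length and block_length may be easier to follow in a simplified plain-Python model of the ordering. This is a sketch of the semantics as I understand them, not TensorFlow's implementation; interleave_model is a hypothetical helper name introduced only for illustration.

```python
from itertools import islice


def interleave_model(inputs, map_fn, cycle_length, block_length):
    """Simplified model of the element order produced by interleave().

    Hypothetical helper for illustration; it ignores parallelism and
    other details of the real tf.data implementation.
    """
    source = iter(inputs)
    # Open up to cycle_length sub-iterators, one per input element.
    slots = [iter(map_fn(x)) for x in islice(source, cycle_length)]
    while slots:
        next_slots = []
        for sub in slots:
            # Take up to block_length consecutive elements from this slot.
            block = list(islice(sub, block_length))
            yield from block
            if len(block) == block_length:
                next_slots.append(sub)  # may still hold more elements
            else:
                # Slot ran dry: refill it from the remaining inputs, if any.
                for x in islice(source, 1):
                    next_slots.append(iter(map_fn(x)))
        slots = next_slots


# Same shape as the call above: inputs 0 and 1 map to the evens and odds.
result = list(interleave_model(range(2), lambda x: range(x, 10, 2),
                               cycle_length=2, block_length=1))
print(result)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

In this model, block_length=1 alternates between the two sub-sequences one element at a time, which is what produces 0 through 9 in order. With block_length=2 the same inputs would come out as 0, 2, 1, 3, 4, 6, 5, 7, 8, 9.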
mrry answered Sep 20 '22 04:09