I'm trying to use tf.data.Dataset to interleave two datasets but having problems doing so. Given this simple example:
import tensorflow as tf

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)
dataset = ...
iterator = dataset.make_one_shot_iterator()
val = iterator.get_next()
What should ... be to produce output like 0, 1, 2, 3, ..., 9?
It seems that dataset.interleave() would be relevant, but I haven't been able to formulate the statement in a way that doesn't generate an error.
MattScarpino is on the right track in his comment. You can use Dataset.zip() along with Dataset.flat_map() to flatten a multi-element dataset:
import tensorflow as tf  # make_one_shot_iterator() below is the TensorFlow 1.x API

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)
# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))
iterator = dataset.make_one_shot_iterator()
val = iterator.get_next()
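In TensorFlow 2.x, eager execution is the default and make_one_shot_iterator() is no longer a Dataset method, so the following is a minimal sketch of the same zip/flat_map idea for that setting. It assumes TensorFlow 2.1 or later (for as_numpy_iterator()) and that every input yields scalars of the same dtype; the helper name round_robin is hypothetical. Stacking each zipped tuple and slicing it back with from_tensor_slices lets one function handle any number of inputs instead of chaining concatenate() calls.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def round_robin(*datasets):
    """Hypothetical helper: take one element from each dataset in turn.

    Assumes every dataset yields scalars of the same dtype. Because zip()
    stops at the shortest input, extra elements of longer inputs are dropped.
    """
    zipped = tf.data.Dataset.zip(tuple(datasets))
    # Each zipped element is a tuple; stack it into a 1-D tensor and slice it
    # back into single elements, which flat_map then emits in order.
    return zipped.flat_map(
        lambda *xs: tf.data.Dataset.from_tensor_slices(tf.stack(xs)))


ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)
result = [int(v) for v in round_robin(ds0, ds1).as_numpy_iterator()]
print(result)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The trade-off versus the concatenate() version is that tf.stack() requires all inputs to share a dtype and shape, while concatenate() only needs compatible element structures.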
Having said this, your intuition about using Dataset.interleave() is pretty sensible. We're investigating ways that you can do this more easily.
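One option available in later TensorFlow releases is Dataset.choose_from_datasets(), which reads from a list of datasets in the order given by a dataset of int64 indices. The sketch below assumes TensorFlow 2.7 or later, where it is a static method of tf.data.Dataset (older releases expose it as tf.data.experimental.choose_from_datasets()); it alternates between the two inputs by cycling the indices 0 and 1.

```python
import tensorflow as tf  # assumes TensorFlow 2.7+ with eager execution

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# An endless int64 stream 0, 1, 0, 1, ... selecting which input to read next.
choices = tf.data.Dataset.range(2).repeat()

# stop_on_empty_dataset=True ends the output as soon as the chosen input is
# exhausted, which also terminates the otherwise endless `choices` stream.
dataset = tf.data.Dataset.choose_from_datasets(
    [ds0, ds1], choices, stop_on_empty_dataset=True)

result = [int(v) for v in dataset.as_numpy_iterator()]
print(result)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Unlike the zip()-based version, this keeps ds0 and ds1 defined exactly as in the question, and the selection pattern can be any sequence of indices rather than a strict round-robin.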
PS. As an alternative, you can use Dataset.interleave() to solve the problem if you change how ds0 and ds1 are defined:
dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)
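To make the role of cycle_length and block_length concrete, here is a minimal sketch that runs the interleave() pipeline eagerly, assuming TensorFlow 2.1 or later; the helper name interleaved is hypothetical. With cycle_length=2 both generated ranges are open at once, and block_length sets how many consecutive elements are taken from one range before switching to the other.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def interleaved(block_length):
    # Element x of range(2) becomes the dataset x, x+2, x+4, ... (< 10);
    # cycle_length=2 keeps both of these inner datasets open at the same time.
    return tf.data.Dataset.range(2).interleave(
        lambda x: tf.data.Dataset.range(x, 10, 2),
        cycle_length=2, block_length=block_length)


one = [int(v) for v in interleaved(1).as_numpy_iterator()]
two = [int(v) for v in interleaved(2).as_numpy_iterator()]
print(one)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(two)  # [0, 2, 1, 3, 4, 6, 5, 7, 8, 9]
```

With block_length=2, two even numbers are read before two odd ones; once the even range runs out after 8, the remaining odd value 9 is emitted on its own.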