
Interleaving multiple TensorFlow datasets together

The current TensorFlow dataset interleave functionality is basically an interleaved flat-map that takes a single dataset as input. Given the current API, what's the best way to interleave multiple datasets together? Say they have already been constructed and I have a list of them. I want to produce elements from them alternately, and I want to support lists with more than 2 datasets (i.e., stacked zips and interleaves would be pretty ugly).

Thanks! :)

@mrry might be able to help.

asked Mar 01 '18 21:03 by eaplatanios


1 Answer

EDIT 2: See tf.contrib.data.choose_from_datasets. It performs deterministic dataset interleaving.
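For illustration, here is a minimal sketch of deterministic round-robin interleaving with this function. It assumes TensorFlow 2.7 or later, where the function is exposed as the static method tf.data.Dataset.choose_from_datasets (the tf.contrib module named above was removed in TensorFlow 2.x). The three toy datasets are hypothetical stand-ins for the list from the question.

```python
import tensorflow as tf

# Toy datasets standing in for the already-built list (hypothetical data).
datasets = [
    tf.data.Dataset.from_tensor_slices([0, 1, 2]),
    tf.data.Dataset.from_tensor_slices([10, 11, 12]),
    tf.data.Dataset.from_tensor_slices([20, 21, 22]),
]

# The choice dataset yields int64 indices 0, 1, 2, 0, 1, 2, ...
# Each index selects which input dataset supplies the next element.
choice = tf.data.Dataset.range(len(datasets)).repeat()

# Stop as soon as the selected dataset has no elements left.
interleaved = tf.data.Dataset.choose_from_datasets(
    datasets, choice, stop_on_empty_dataset=True
)

result = [int(x) for x in interleaved]
print(result)  # [0, 10, 20, 1, 11, 21, 2, 12, 22]
```

Because the choice dataset is just a stream of indices, any fixed pattern (for example 0, 0, 1) can be expressed the same way.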

EDIT: See tf.contrib.data.sample_from_datasets. Even though it performs random sampling rather than strict alternation, I guess it can be useful.
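A rough sketch of the random-sampling variant, assuming TensorFlow 2.7 or later, where it is available as tf.data.Dataset.sample_from_datasets. The toy data, the equal weights and the seed are all illustrative choices, not part of the original answer.

```python
import tensorflow as tf

# Toy datasets (hypothetical data).
datasets = [
    tf.data.Dataset.from_tensor_slices([0, 1, 2]),
    tf.data.Dataset.from_tensor_slices([10, 11, 12]),
]

# Each element is drawn from a randomly chosen input dataset.
# Equal weights; the fixed seed only makes this sketch repeatable.
mixed = tf.data.Dataset.sample_from_datasets(
    datasets, weights=[0.5, 0.5], seed=42
)

# With the default stop_on_empty_dataset=False, sampling continues
# until every input dataset is exhausted.
result = [int(x) for x in mixed]
print(sorted(result))
```

The order across datasets is random, but the elements of each individual dataset still appear in their original relative order.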


Even though this is not "clean", it is the only workaround I came up with.

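import tensorflow as tf

datasets = [tf.data.Dataset...]

def concat_datasets(datasets):
    # Despite the name, `datasets` here holds one element from each zipped
    # dataset. Wrap each element in a one-element dataset and chain them.
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0

# zip yields one tuple per step (one element from each dataset); flat_map
# flattens each tuple back into a stream, so the elements alternate.
# zip stops at the shortest dataset, and all datasets must share the same
# element structure and dtypes for concatenate to work.
ds = tf.data.Dataset.zip(tuple(datasets)).flat_map(
    lambda *args: concat_datasets(args)
)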
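To make the behaviour of this workaround concrete, here is a small self-contained sketch that applies the same zip + flat_map idea to two toy datasets of unequal length. The helper names interleave_round_robin and elements_to_dataset are hypothetical, and the data is made up for illustration.

```python
import tensorflow as tf

def interleave_round_robin(datasets):
    """Alternate elements of `datasets` via the zip + flat_map workaround."""
    def elements_to_dataset(*elements):
        # `elements` holds one element from each input dataset for this step.
        ds = tf.data.Dataset.from_tensors(elements[0])
        for element in elements[1:]:
            ds = ds.concatenate(tf.data.Dataset.from_tensors(element))
        return ds
    return tf.data.Dataset.zip(tuple(datasets)).flat_map(elements_to_dataset)

# Toy inputs of unequal length (hypothetical data).
a = tf.data.Dataset.from_tensor_slices([0, 1, 2])
b = tf.data.Dataset.from_tensor_slices([10, 11])

result = [int(x) for x in interleave_round_robin([a, b])]
print(result)  # [0, 10, 1, 11]
```

Note that the trailing element 2 is dropped: zip ends with the shortest input, so this approach only interleaves up to the length of the shortest dataset.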
answered Nov 01 '22 16:11 by user2781994