TensorFlow tf.data.Dataset API: is there a dataset unzip function?

TensorFlow 1.12 provides the Dataset.zip function. I was wondering whether there is a matching dataset unzip function that returns the original two datasets.

For reference, the Dataset.zip documentation gives these examples:

# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { 4, 5, 6 }
c = { (7, 8), (9, 10), (11, 12) }
d = { 13, 14 }

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

# The `datasets` argument may contain an arbitrary number of
# datasets.
Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                            (2, 5, (9, 10)),
                            (3, 6, (11, 12)) }

# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
Dataset.zip((a, d)) == { (1, 13), (2, 14) }

I would like to be able to write something like the following:

dataset = Dataset.zip((a, d)) == { (1, 13), (2, 14) }
a, d = dataset.unzip()
Ouwen Huang asked Dec 05 '18 22:12



2 Answers

My workaround was to use map to pull out each component. I'm not sure whether there would be interest in adding unzip as syntactic sugar later, though.

a = dataset.map(lambda a, b: a)
b = dataset.map(lambda a, b: b)
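
For anyone who wants to try this end to end, here is a minimal sketch using the `a` and `d` datasets from the question. It assumes TensorFlow 2.x with eager execution (in 1.12 you would read the values through an iterator and a session instead). The names `zipped`, `first`, and `second` are illustrative and not part of any API.

```python
import tensorflow as tf

# Rebuild the `a` and `d` datasets from the question.
a = tf.data.Dataset.from_tensor_slices([1, 2, 3])
d = tf.data.Dataset.from_tensor_slices([13, 14])

# Zipping stops at the shorter dataset, so this yields (1, 13), (2, 14).
zipped = tf.data.Dataset.zip((a, d))

# "Unzip" by projecting out one component of each tuple with map.
first = zipped.map(lambda x, y: x)
second = zipped.map(lambda x, y: y)

# Collect the values as plain Python ints for inspection.
first_values = [int(v) for v in first.as_numpy_iterator()]
second_values = [int(v) for v in second.as_numpy_iterator()]
print(first_values, second_values)
```

Two caveats: the element of `a` that was dropped by the zip is not recovered, and `first` and `second` each iterate the zipped pipeline independently, so any expensive or randomized upstream step (for example an unseeded shuffle) may run twice and need not line up between the two.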
Ouwen Huang answered Sep 27 '22 20:09


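Building on Ouwen Huang's answer, this function seems to work for any dataset whose elements are dictionaries (it relies on dataset.element_spec.keys(), so tuple-structured datasets are not covered):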

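def split_datasets(dataset):
    """Split a dataset of dict elements into a dict of per-key datasets."""
    tensors = {}
    # element_spec mirrors the element structure, so its keys are the feature names.
    names = list(dataset.element_spec.keys())
    for name in names:
        tensors[name] = dataset.map(lambda x: x[name])

    return tensors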
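
A short usage sketch, again assuming TensorFlow 2.x with eager execution. The function is repeated so the snippet runs on its own; the feature names "x" and "y" and the variable `features` are made up for illustration.

```python
import tensorflow as tf

def split_datasets(dataset):
    # Same function as above: one mapped dataset per dictionary key.
    tensors = {}
    names = list(dataset.element_spec.keys())
    for name in names:
        tensors[name] = dataset.map(lambda x: x[name])
    return tensors

# A dataset whose elements are dicts: {"x": 1, "y": 4}, {"x": 2, "y": 5}, ...
features = tf.data.Dataset.from_tensor_slices({"x": [1, 2, 3], "y": [4, 5, 6]})

parts = split_datasets(features)

# Collect each per-key dataset as plain Python ints.
x_values = [int(v) for v in parts["x"].as_numpy_iterator()]
y_values = [int(v) for v in parts["y"].as_numpy_iterator()]
print(sorted(parts.keys()), x_values, y_values)
```

The result is a plain Python dict of datasets, so `parts["x"]` and `parts["y"]` can be batched or zipped back together like any other dataset.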
markemus answered Sep 27 '22 18:09