
tf.contrib.data.Dataset repeat with shuffle, notice epoch end, mixed epochs?

Tags: tensorflow
About tf.contrib.data.Dataset usage (available from TensorFlow 1.2, see here and here): when I use repeat (for multiple epochs) together with shuffle (as read_batch_features does internally), how can I tell when an epoch ends, and which epoch I am currently in? Also, when an epoch ends, will the ShuffleDataset first wait until everything has been dequeued, or will it already be filled with data from the next epoch? In the last epoch, or if I don't use repeat, will the ShuffleDataset dequeue all remaining data, as tf.RandomShuffleQueue does after close?

My current solution, which also gives me more control: I do not use repeat. Instead I go over the data once and use ShuffleDataset to get shuffling like RandomShuffleQueue. At some point I get an OutOfRangeError, which tells me that I have reached the end of the epoch. Then I reinitialize the iterator, as described here.

Asked by Albert on May 23 '17 at 10:05



1 Answer

The behavior of Dataset.shuffle() depends on where in your pipeline it appears relative to the Dataset.repeat():

  • If you shuffle before the repeat, the sequence of outputs will first produce all records from epoch i, before any record from epoch i + 1.

  • If you shuffle after the repeat, the sequence of outputs may produce records from epoch i before or after records from epoch i + 1 (and records from epoch i + k, with a probability that increases with buffer_size and decreases with k).
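The difference between the two orderings can be illustrated with a small pure-Python model of a streaming shuffle buffer. This is a simplified sketch for intuition only, not TensorFlow's implementation; the helper names (shuffle_buffer, shuffle_then_repeat, repeat_then_shuffle) are hypothetical, and each record is tagged with its epoch so that mixing becomes visible. The model also drains the buffer once the input is exhausted, which corresponds to the "dequeue all remaining data" behavior asked about in the question.

```python
import random


def shuffle_buffer(source, buffer_size, rng):
    """Simplified model of a streaming shuffle buffer (buffer_size >= 1)."""
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit everything that is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


def one_epoch(epoch, n):
    # Each record is (epoch number, record index).
    return [(epoch, i) for i in range(n)]


def shuffle_then_repeat(num_epochs, n, buffer_size, rng):
    # A fresh shuffle buffer per epoch, so epochs never mix.
    for epoch in range(num_epochs):
        yield from shuffle_buffer(one_epoch(epoch, n), buffer_size, rng)


def repeat_then_shuffle(num_epochs, n, buffer_size, rng):
    # One shuffle buffer over the concatenated epochs, so records
    # near an epoch boundary can be emitted out of epoch order.
    stream = (rec for epoch in range(num_epochs) for rec in one_epoch(epoch, n))
    yield from shuffle_buffer(stream, buffer_size, rng)


rng = random.Random(0)
before = list(shuffle_then_repeat(num_epochs=2, n=5, buffer_size=3, rng=rng))
after = list(repeat_then_shuffle(num_epochs=2, n=5, buffer_size=3, rng=rng))
print("shuffle, then repeat:", before)
print("repeat, then shuffle:", after)
```

In the first list all epoch-0 records always come before any epoch-1 record. In the second list that ordering is not guaranteed, although both lists contain exactly the same records.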

If you want to perform some computation between epochs, and avoid mixing data from different epochs, it is probably easiest to avoid repeat() and catch the OutOfRangeError at the end of each epoch.
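A minimal sketch of that per-epoch loop is shown here, under the assumption of TensorFlow 2.x with eager execution and tf.data.Dataset rather than the tf.contrib.data API from the question. Creating a new iterator with iter(dataset) plays the role of reinitializing the iterator, and get_next() raises tf.errors.OutOfRangeError when the pass is finished. The helper name run_epochs and the toy dataset are hypothetical.

```python
import tensorflow as tf


def run_epochs(dataset, num_epochs):
    """Iterate over `dataset` once per epoch, detecting each epoch end."""
    per_epoch = []
    for epoch in range(num_epochs):
        # A fresh iterator starts a fresh pass over the data.
        iterator = iter(dataset)
        records = []
        while True:
            try:
                records.append(int(iterator.get_next()))
            except tf.errors.OutOfRangeError:
                # End of this epoch: the shuffle buffer has been fully drained.
                break
        # Between-epoch computation (evaluation, checkpointing, ...) would go here.
        per_epoch.append(records)
    return per_epoch


# Toy single-epoch dataset: shuffle() without repeat().
dataset = tf.data.Dataset.range(5).shuffle(buffer_size=5)
epochs = run_epochs(dataset, num_epochs=2)
print(epochs)
```

Because there is no repeat(), every inner list holds each record exactly once, and the epoch number is simply the Python loop variable.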

There are some more interesting pipelines you could build to track the epoch number. For example, you could encode an epoch number as a component of each element:

dataset = (
    Dataset.range(None).flat_map(lambda epoch_num: 
        Dataset.zip(
            (Dataset.from_tensors(epoch_num).repeat(),  # Infinite repeat of `epoch_num`.
             ...,  # Definition of a Dataset over a single epoch.
            )
        )
    )
)

...where ... is the expression that defines a Dataset for a single epoch, and includes batching and shuffling.
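For reference, here is a runnable variant of the same idea under stated assumptions: TensorFlow 2.x with tf.data.Dataset, a bounded number of epochs (NUM_EPOCHS) in place of the unbounded range above, and a toy single_epoch() helper standing in for the elided single-epoch definition. Both names are hypothetical and only serve the illustration.

```python
import tensorflow as tf

NUM_EPOCHS = 3  # Assumption: bounded so the demo terminates.
EPOCH_SIZE = 4  # Assumption: size of the toy single-epoch dataset.


def single_epoch():
    # Hypothetical stand-in for the single-epoch Dataset (including shuffling).
    return tf.data.Dataset.range(EPOCH_SIZE).shuffle(buffer_size=EPOCH_SIZE)


dataset = tf.data.Dataset.range(NUM_EPOCHS).flat_map(
    lambda epoch_num: tf.data.Dataset.zip(
        (
            # Infinite repeat of `epoch_num`; zip stops at the shorter dataset.
            tf.data.Dataset.from_tensors(epoch_num).repeat(),
            single_epoch(),
        )
    )
)

# Each element is an (epoch number, record) pair.
pairs = [(int(e), int(v)) for e, v in dataset.as_numpy_iterator()]
print(pairs)
```

Since the shuffle happens inside each epoch's dataset, the epoch component increases monotonically and a change in its value marks an epoch boundary.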

Answered by mrry on Oct 16 '22 at 20:10