
What do batch, repeat, and shuffle do with a TensorFlow Dataset?

People also ask

What does Dataset.repeat() do?

The repeat() method of the tf.data.Dataset class repeats the dataset count times. If repeat(count=None) or repeat(count=-1) is specified, the dataset is repeated indefinitely.
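The count semantics can be modelled in a few lines of plain Python. This is a sketch, not tf.data's implementation: repeat_model is a hypothetical name, and the source is assumed to be a re-iterable collection such as a list.

```python
import itertools

def repeat_model(source, count=None):
    """Hypothetical model of Dataset.repeat(): replay `source` `count` times.

    count=None or count=-1 means "repeat indefinitely".
    """
    if count is None or count == -1:
        # Restart iteration over the source forever.
        return itertools.chain.from_iterable(itertools.repeat(source))
    # Replay the source exactly `count` times (count=0 gives nothing).
    return itertools.chain.from_iterable(itertools.repeat(source, count))

print(list(repeat_model([1, 2, 3], count=2)))              # [1, 2, 3, 1, 2, 3]
print(list(itertools.islice(repeat_model([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```

An infinite repeat must be consumed lazily, which is why the second call takes only the first seven elements.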

What does tf.data.Dataset.from_tensor_slices do?

from_tensors makes a dataset with a single element that holds all of the input tensors, whereas from_tensor_slices slices its inputs along the first dimension, so each input tensor acts like a column of your data; in the latter case all tensors must have the same length, and the elements (rows) of the resulting dataset are tuples with one ...
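A toy model of the difference, using Python lists in place of tensors. from_tensors_model and from_tensor_slices_model are hypothetical names, and the slicing model only covers the case where the input is a tuple of equal-length columns:

```python
def from_tensors_model(tensors):
    """Hypothetical model of from_tensors: the whole input becomes ONE element."""
    return [tensors]

def from_tensor_slices_model(columns):
    """Hypothetical model of from_tensor_slices for a tuple of columns.

    Row i of the result is a tuple holding element i of every column.
    """
    lengths = {len(c) for c in columns}
    if len(lengths) > 1:
        raise ValueError("all inputs must have the same first dimension")
    return list(zip(*columns))

features = [10, 20, 30]
labels = [0, 1, 0]
print(from_tensors_model((features, labels)))        # [([10, 20, 30], [0, 1, 0])]
print(from_tensor_slices_model((features, labels)))  # [(10, 0), (20, 1), (30, 0)]
```

The first call yields one element containing both columns; the second yields three (feature, label) rows.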

What does prefetch do in TensorFlow?

The prefetch transformation can be used to decouple the time when data is produced from the time when data is consumed. In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested.
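The background-thread-plus-buffer idea can be sketched as a plain producer/consumer pair. This is an illustrative model, not how tf.data implements prefetch internally; prefetch_model is a hypothetical name, and it assumes a thread plus a bounded queue.Queue as the internal buffer:

```python
import queue
import threading

_END = object()  # sentinel marking the end of the source

def prefetch_model(source, buffer_size):
    """Yield items from `source`, produced ahead of time by a background thread."""
    buf = queue.Queue(maxsize=buffer_size)  # bounded internal buffer

    def producer():
        for item in source:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()  # blocks until the producer has an element ready
        if item is _END:
            return
        yield item

print(list(prefetch_model(range(6), buffer_size=2)))  # [0, 1, 2, 3, 4, 5]
```

The element order is unchanged; only the timing differs, because the producer keeps up to buffer_size elements ready while the consumer is busy.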


Update: here is a small Colab notebook demonstrating this answer.

Imagine you have a dataset [1, 2, 3, 4, 5, 6]. Then:

How ds.shuffle() works

dataset.shuffle(buffer_size=3) allocates a buffer of size 3 from which random entries are picked. The buffer is fed by the source dataset. We can imagine it like this:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Let's assume that entry 2 is taken from the random buffer. The freed slot is filled by the next element from the source dataset, which is 4:

2 <= [1,3,4] <= [5,6]

We continue reading till nothing is left:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []
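The walk-through above can be reproduced with a small pure-Python model of a buffered shuffle. It is a sketch under the assumption that each output is drawn uniformly from the buffer and its slot is then refilled from the source, as in the diagrams; buffered_shuffle is a hypothetical name, not a TensorFlow function, and buffer_size is assumed to be at least 1.

```python
import itertools
import random

def buffered_shuffle(source, buffer_size, seed=None):
    """Yield the elements of `source` in a buffer-limited random order."""
    rng = random.Random(seed)
    it = iter(source)
    buffer = list(itertools.islice(it, buffer_size))  # initial fill
    for item in it:
        i = rng.randrange(len(buffer))
        yield buffer[i]   # emit a random buffered element...
        buffer[i] = item  # ...and refill its slot from the source
    # Source exhausted: drain what is left in the buffer, still at random.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

print(list(buffered_shuffle([1, 2, 3, 4, 5, 6], buffer_size=3, seed=0)))
```

With buffer_size=1 the order never changes, and with a buffer at least as large as the input any permutation is possible. In between, the element at 0-based source position j cannot be emitted before output position j - buffer_size + 1, which is why 6 (position 5) can appear no earlier than fourth when the buffer holds 3 elements.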

How ds.repeat() works

As soon as all the entries have been read from the dataset and you try to read the next element, the iterator signals the end of the data (tf.errors.OutOfRangeError in TF1 graph mode, StopIteration when iterating eagerly in Python). That's where ds.repeat() comes into play: it re-initializes the dataset, making it look like this again:

[1,2,3] <= [4,5,6]

What will ds.batch() produce

ds.batch() takes the first batch_size entries and makes a batch out of them. So a batch size of 3 for our example dataset produces two batch records:

[2,1,5]
[3,6,4]

Because ds.repeat() comes before ds.batch(), the generation of data continues after the first pass, but the elements appear in a different order each time, due to ds.shuffle(). Note that 6 can never be present in the first batch, because of the size of the shuffle buffer.
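To see the whole pipeline in one place, here is a pure-Python model of ds.shuffle(3).repeat().batch(3) on the example data, assuming the order shuffle, then repeat, then batch. It is a sketch, not TensorFlow code: stream and batches are hypothetical helpers, the per-epoch reseeding stands in for reshuffle_each_iteration, and buffer_size is assumed to be at least 1.

```python
import itertools
import random

def buffered_shuffle(source, buffer_size, rng):
    """Yield `source` in a buffer-limited random order drawn from `rng`."""
    it = iter(source)
    buffer = list(itertools.islice(it, buffer_size))  # initial fill
    for item in it:
        i = rng.randrange(len(buffer))
        yield buffer[i]   # emit a random buffered element...
        buffer[i] = item  # ...and refill its slot from the source
    while buffer:         # source exhausted: drain the rest at random
        yield buffer.pop(rng.randrange(len(buffer)))

def stream(source, buffer_size, seed=None, reshuffle_each_iteration=True):
    """Endless model of shuffle(buffer_size) followed by repeat()."""
    master = random.Random(seed)
    fixed_seed = master.randrange(2**32)
    while True:  # repeat(): start a new pass (epoch) over the source
        epoch_seed = (master.randrange(2**32) if reshuffle_each_iteration
                      else fixed_seed)
        yield from buffered_shuffle(source, buffer_size, random.Random(epoch_seed))

def batches(elements, batch_size):
    """Group consecutive elements into lists of `batch_size`."""
    it = iter(elements)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch

data = [1, 2, 3, 4, 5, 6]
pipeline = batches(stream(data, buffer_size=3, seed=1), batch_size=3)
for batch in itertools.islice(pipeline, 4):  # the stream is endless, so take 4
    print(batch)
```

Because the epochs are chained before batching, a batch can also straddle two epochs when the dataset size is not a multiple of batch_size. Setting reshuffle_each_iteration=False in this model makes every epoch repeat the first epoch's order.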


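The following methods are available in tf.data.Dataset:

  1. repeat(count=None): repeats the dataset count times. With the default count=None, or with count=-1, the dataset is repeated indefinitely.
  2. shuffle(buffer_size, seed=None, reshuffle_each_iteration=None): randomly shuffles the elements of the dataset. buffer_size is the number of elements kept in the buffer from which each output element is sampled at random; for perfect shuffling the buffer must be at least as large as the dataset. reshuffle_each_iteration controls whether a new order is drawn on every pass over the dataset (reshuffling is the default).
  3. batch(batch_size, drop_remainder=False): combines consecutive elements into batches of batch_size elements. If the number of elements is not a multiple of batch_size, the last batch is smaller, unless drop_remainder=True, in which case it is dropped.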
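The batch(batch_size, drop_remainder=False) behaviour, including the short final batch, can be modelled like this. batch_model is a hypothetical pure-Python stand-in that returns lists, whereas tf.data produces batched tensors:

```python
def batch_model(source, batch_size, drop_remainder=False):
    """Group consecutive elements of `source` into lists of `batch_size`."""
    batch = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A final partial batch is kept unless drop_remainder=True.
    if batch and not drop_remainder:
        yield batch

shuffled = [2, 1, 5, 3, 6, 4]  # the order from the shuffle walk-through
print(list(batch_model(shuffled, 3)))                       # [[2, 1, 5], [3, 6, 4]]
print(list(batch_model(shuffled, 4)))                       # [[2, 1, 5, 3], [6, 4]]
print(list(batch_model(shuffled, 4, drop_remainder=True)))  # [[2, 1, 5, 3]]
```

Dropping the remainder keeps every batch the same size, which is useful when later steps require a fixed batch dimension.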