
How to pad to fixed BATCH_SIZE in tf.data.Dataset?

I have a dataset with 11 samples, and when I set BATCH_SIZE to 2, the following code raises an error:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

The problem lies in dataset = dataset.batch(batch_size): when the dataset reaches the last batch, only 1 sample remains. Is there any way to pick one of the previously visited samples at random and use it to fill out the last batch?
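For reference, the "fill the short last batch by resampling earlier elements" idea described above can be sketched in plain Python, independent of tf.data (the function name and seed are illustrative, not part of any TensorFlow API):

```python
import random

def batch_with_padding(samples, batch_size, seed=0):
    """Split samples into batches; fill a short final batch with
    elements drawn at random from the previously visited samples."""
    rng = random.Random(seed)
    batches = [samples[i:i + batch_size]
               for i in range(0, len(samples), batch_size)]
    last = batches[-1]
    if len(last) < batch_size:
        # pad by resampling from the samples seen before the last batch
        seen = samples[:len(samples) - len(last)]
        batches[-1] = last + rng.choices(seen, k=batch_size - len(last))
    return batches

batches = batch_with_padding(list(range(11)), 2)
# 11 samples with batch size 2 -> 6 batches, each of exactly 2 elements
```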

mining asked Dec 23 '22
1 Answer

@mining proposes a solution by padding the filenames.

Another solution is to use tf.contrib.data.batch_and_drop_remainder. This will batch the data with a fixed batch size and drop the last smaller batch.

In your example, with 11 inputs and a batch size of 2, this would yield 5 batches of 2 elements; the 11th element is dropped.

Here is the example from the documentation:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
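Note that tf.contrib was removed in TensorFlow 2.x; the same behavior is now available directly as dataset.batch(2, drop_remainder=True). Since TensorFlow may not be installed everywhere, the drop-remainder semantics can be sketched in plain Python (the helper name here is illustrative):

```python
def batch_and_drop_remainder(elements, batch_size):
    """Group elements into full batches of batch_size;
    any trailing partial batch is dropped entirely."""
    n_full = len(elements) // batch_size  # number of complete batches
    return [elements[i * batch_size:(i + 1) * batch_size]
            for i in range(n_full)]

batches = batch_and_drop_remainder(list(range(11)), 2)
# 11 elements with batch size 2 -> 5 batches; the lone 11th element (10) is dropped
```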
Olivier Moindrot answered Feb 12 '23