I have a dataset with 11 samples. When I set BATCH_SIZE to 2, the following code fails:
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
The problem lies in dataset = dataset.batch(batch_size): when the dataset reaches the last batch, only 1 sample remains. Is there a way to pick one of the previously visited samples at random and use it to fill out the last batch?
@mining proposes a solution by padding the filenames.
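The padding approach itself isn't spelled out above; a minimal pure-Python sketch of the idea (the helper name pad_to_multiple and the random-pick strategy are my own, not from @mining's answer) could look like this:

```python
import random

def pad_to_multiple(samples, batch_size, seed=0):
    # Pad the list of samples with randomly chosen earlier samples
    # so its length becomes a multiple of batch_size.
    rng = random.Random(seed)
    padded = list(samples)
    remainder = len(padded) % batch_size
    if remainder:
        padded.extend(rng.choice(samples)
                      for _ in range(batch_size - remainder))
    return padded

# 11 filenames with batch_size=2 -> padded to 12, giving 6 full batches
filenames = ["sample_%d.tfrecord" % i for i in range(11)]
padded = pad_to_multiple(filenames, 2)
```

The padded list of filenames can then be fed to the dataset pipeline as usual, and every batch will be full.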
Another solution is to use tf.contrib.data.batch_and_drop_remainder. This batches the data with a fixed batch size and drops the last, smaller batch.
In your example, with 11 inputs and a batch size of 2, this yields 5 batches of 2 elements.
Here is the example from the documentation:
dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
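Note that tf.contrib was removed in TensorFlow 2.x; there, the same behavior is available directly through the drop_remainder argument of Dataset.batch. A sketch of the equivalent TF 2.x code (assuming TensorFlow 2.x is installed):

```python
import tensorflow as tf

# Same 11-element dataset as in the contrib example above
dataset = tf.data.Dataset.range(11)

# drop_remainder=True discards the final partial batch,
# so every emitted batch has exactly batch_size elements
batched = dataset.batch(2, drop_remainder=True)

sizes = [len(b) for b in batched.as_numpy_iterator()]
# 5 full batches of 2; the leftover 11th element is dropped
```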