I am converting some legacy code to use the Dataset API. This code uses feed_dict to feed one batch to the train operation (actually three times) and then recalculates the losses for display using the same batch. So I need an iterator that returns the exact same batch two (or several) times. Unfortunately, I can't find a way of doing this with TensorFlow datasets. Is it possible?
You can repeat individual elements of a Dataset by combining Dataset.flat_map(), Dataset.from_tensors() and Dataset.repeat(). For example, to repeat each element twice:
import tensorflow as tf

NUM_REPEATS = 2
dataset = tf.data.Dataset.range(10)  # ...or the output of `.batch()`, etc.
# Repeat each element of `dataset` NUM_REPEATS times in a row: 0, 0, 1, 1, 2, 2, ...
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))
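For the use case in the question, the order of operations matters: call `.batch()` first and then repeat, so that whole batches (not individual examples) are duplicated. Below is a minimal sketch, assuming TF 2.x eager execution; `repeat_each_element`, `BATCH_SIZE`, and the value `NUM_REPEATS = 3` are illustrative choices, not part of the original answer.

```python
import tensorflow as tf

NUM_REPEATS = 3  # illustrative: e.g. three train steps on the same batch
BATCH_SIZE = 4   # illustrative batch size

def repeat_each_element(dataset, num_repeats):
    """Hypothetical helper: emit every element of `dataset` num_repeats times in a row."""
    return dataset.flat_map(
        lambda x: tf.data.Dataset.from_tensors(x).repeat(num_repeats))

# Batch first, then repeat, so each *batch* is duplicated as a unit.
dataset = tf.data.Dataset.range(10).batch(BATCH_SIZE)  # [0..3], [4..7], [8, 9]
dataset = repeat_each_element(dataset, NUM_REPEATS)

# Collect the batches as Python lists to inspect the order.
batches = [b.numpy().tolist() for b in dataset]
for b in batches:
    print(b)
```

Each batch appears three times consecutively before the next one starts, and the final partial batch `[8, 9]` is repeated as well. If the dataset elements are `(features, labels)` tuples, `flat_map` passes the components as separate arguments, so the lambda would need to be written as `lambda *x: tf.data.Dataset.from_tensors(x).repeat(num_repeats)`.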