I find the reshuffle_each_iteration parameter to tf.data.Dataset.shuffle a bit confusing. Consider this code:
import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
filenames = tf.constant(flist)
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)

it_train_x = train_x_dataset.make_initializable_iterator()
next_sample = it_train_x.get_next()

with tf.Session() as sess:
    for epoch in range(3):
        sess.run(it_train_x.initializer)
        print("Starting epoch ", epoch)
        while True:
            try:
                s = sess.run(next_sample)
                print("Sample: ", s)
            except tf.errors.OutOfRangeError:
                break
The code outputs:
Starting epoch 0
Sample: b'trimg2'
Sample: b'trimg1'
Sample: b'trimg3'
Sample: b'trimg4'
Starting epoch 1
Sample: b'trimg4'
Sample: b'trimg3'
Sample: b'trimg2'
Sample: b'trimg1'
Starting epoch 2
Sample: b'trimg3'
Sample: b'trimg2'
Sample: b'trimg4'
Sample: b'trimg1'
Even though reshuffle_each_iteration is False, TensorFlow still reshuffles the dataset after it has been iterated over. Is there another way to reset the iterator? What is the expected behavior of reshuffle_each_iteration?
I know I can fix the seed and get the same order each time; the question is about how reshuffle_each_iteration is supposed to work.
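For reference, the seed workaround I mean is just a one-argument change (the seed value 1 here is arbitrary):

# Fixing the seed makes every re-initialization produce the same shuffled order.
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, seed=1,
                                          reshuffle_each_iteration=False)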
I'm also aware that the more idiomatic way to do epochs is to use repeat(), but in my case the actual number of samples will be different for each epoch.
I suspect TensorFlow still reshuffles the dataset on every pass of your for loop because you re-initialize the iterator on each pass: every time the iterator is initialized, the shuffle transformation is applied to the dataset again.
The expected behavior is that the iterator is initialized once, and reshuffle_each_iteration lets you choose whether to reshuffle on repeats of the data (i.e. every time the original data has been iterated over).
I am not sure how to reformulate your code to deal with the variable number of samples, but here is your code modified with the repeat() function to demonstrate my claim:
flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
filenames = tf.constant(flist)
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
train_x_dataset = train_x_dataset.repeat(4)

it_train_x = train_x_dataset.make_initializable_iterator()
next_sample = it_train_x.get_next()

with tf.Session() as sess:
    sess.run(it_train_x.initializer)
    while True:
        try:
            s = sess.run(next_sample)
            print("Sample: ", s)
        except tf.errors.OutOfRangeError:
            break
Output:
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Whereas if I set reshuffle_each_iteration=True, I get:
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg3
Sample: trimg2
Sample: trimg1
Sample: trimg4
Sample: trimg3
Sample: trimg1
Sample: trimg2
Sample: trimg4
Sample: trimg4
Sample: trimg1
Sample: trimg2
Sample: trimg3
Hope this helps!
Edit: Further evidence for my claim: these two test functions in the TensorFlow codebase. In that case, a one-shot iterator is used, so it is only initialized once. A batch size of 10 is used on data of size 10, so each call of iterator.get_next() covers the entire source data. The test checks whether successive calls of that function return the same (shuffled) array.
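A minimal sketch along the same lines (this is not the actual test code, just my reconstruction of the idea):

import tensorflow as tf
import numpy as np

# Shuffle 10 elements, repeat the data 3 times, and batch a full pass at a time.
data = tf.data.Dataset.from_tensor_slices(np.arange(10))
data = data.shuffle(buffer_size=10, reshuffle_each_iteration=False)
data = data.repeat(3).batch(10)

# A one-shot iterator is initialized exactly once.
it = data.make_one_shot_iterator()
next_batch = it.get_next()

with tf.Session() as sess:
    first = sess.run(next_batch)
    # With reshuffle_each_iteration=False, every later pass over the data
    # should come back in the same shuffled order as the first one.
    for _ in range(2):
        assert np.array_equal(first, sess.run(next_batch))

With reshuffle_each_iteration=True the assertion would generally fail, since each pass gets a fresh permutation.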
The discussion on this issue might shed some more light on the intended use and expected behavior of the different iterators, and perhaps help you figure out a solution to your specific problem.
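For what it's worth, one possible direction for your variable-number-of-samples case (an untested sketch, which assumes it is acceptable for each epoch to continue from wherever the previous one stopped in the stream) is to initialize the iterator only once, repeat the dataset indefinitely, and decide in the Python loop how many samples each epoch draws:

import tensorflow as tf

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
filenames = tf.constant(flist)
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
# Shuffle once; repeat() with no count repeats the (identically ordered) data forever.
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
train_x_dataset = train_x_dataset.repeat()

it_train_x = train_x_dataset.make_initializable_iterator()
next_sample = it_train_x.get_next()

samples_per_epoch = [4, 3, 2]  # hypothetical per-epoch counts, known on the Python side

with tf.Session() as sess:
    sess.run(it_train_x.initializer)  # initialize once, so shuffle() is applied only once
    for epoch, n in enumerate(samples_per_epoch):
        print("Starting epoch ", epoch)
        for _ in range(n):
            print("Sample: ", sess.run(next_sample))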