
TensorFlow Dataset shuffle behavior on iterator reset

I find the reshuffle_each_iteration parameter of tf.data.Dataset.shuffle a bit confusing. Consider this code:

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, make_initializable_iterator)

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    for epoch in range(3):
        sess.run(it_train_x.initializer)
        print("Starting epoch ", epoch)
        while True:
            try:
                s = sess.run(next_sample)
                print("Sample: ", s)
            except tf.errors.OutOfRangeError:
                break

The code outputs:

Starting epoch  0
Sample:  b'trimg2'
Sample:  b'trimg1'
Sample:  b'trimg3'
Sample:  b'trimg4'
Starting epoch  1
Sample:  b'trimg4'
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg1'
Starting epoch  2
Sample:  b'trimg3'
Sample:  b'trimg2'
Sample:  b'trimg4'
Sample:  b'trimg1'

Even though reshuffle_each_iteration is False, TensorFlow still reshuffles the dataset each time it is iterated over. Is there another way to reset the iterator? What is the expected behavior of reshuffle_each_iteration?

I know I can fix the seed to get the same order each time; the question is about how reshuffle_each_iteration is supposed to work.

I'm also aware that the more idiomatic way to do epochs is to use repeat(), but in my case the number of samples will differ from epoch to epoch.

Łukasz Sromek asked Mar 08 '23 01:03


1 Answer

I suspect TensorFlow still reshuffles the dataset on every iteration of your for loop because the iterator is re-initialized on each iteration. Every time the iterator is initialized, the shuffle is applied to the dataset anew.
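To make that mechanism concrete, here is a pure-Python toy model, assuming the behavior described above: each initialization re-applies the shuffle, and when no seed is fixed a fresh seed is drawn, so the order can change between initializations, while a fixed seed reproduces the same order every time. The class ToyShuffledIterator and its methods are hypothetical; this is not TensorFlow's implementation.

```python
import random


class ToyShuffledIterator:
    """Hypothetical stand-in for an initializable iterator over a shuffled dataset.

    Toy model of the behavior described in the answer, not TensorFlow's code.
    """

    def __init__(self, elements, seed=None):
        self.elements = list(elements)
        self.seed = seed  # None means "draw a fresh seed on every initialization"
        self._order = []

    def initialize(self):
        # Each initialization re-applies the shuffle. Without a fixed seed,
        # a new seed is drawn, so the resulting order can differ each time.
        seed = self.seed if self.seed is not None else random.randrange(2**32)
        rng = random.Random(seed)
        self._order = self.elements[:]
        rng.shuffle(self._order)

    def epoch(self):
        # Drain the iterator, like the while/try loop over get_next().
        return list(self._order)


flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
unseeded = ToyShuffledIterator(flist)
seeded = ToyShuffledIterator(flist, seed=42)

unseeded_epochs, seeded_epochs = [], []
for epoch in range(3):
    unseeded.initialize()  # re-initialize every epoch, as in the question
    seeded.initialize()
    unseeded_epochs.append(unseeded.epoch())
    seeded_epochs.append(seeded.epoch())

print("unseeded:", unseeded_epochs)  # orders may differ between epochs
print("seeded:  ", seeded_epochs)    # same order in every epoch
```

The seeded case mirrors the "fix the seed" workaround mentioned in the question: re-initializing still re-runs the shuffle, but with the same seed it produces the same permutation.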

The expected usage is that the iterator is initialized once; reshuffle_each_iteration then controls whether the data is reshuffled on each repetition, i.e. every time the original data has been fully iterated over (for example, via repeat()).
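A pure-Python sketch of that contract, assuming a single iterator that is initialized once and consumes shuffle followed by repeat: with the flag off, every pass reuses the first order; with it on, each later pass draws a new order from the same random generator. The helper shuffle_then_repeat is hypothetical and only models the semantics, not TensorFlow's implementation.

```python
import random


def shuffle_then_repeat(elements, count, reshuffle_each_iteration, seed=0):
    """Toy model of shuffle(...).repeat(count) consumed by ONE initialized iterator.

    Hypothetical helper for illustration only, not TensorFlow's implementation.
    """
    rng = random.Random(seed)  # created once, when the iterator is initialized
    first_order = list(elements)
    rng.shuffle(first_order)

    output = []
    for i in range(count):
        if i == 0 or not reshuffle_each_iteration:
            # reshuffle_each_iteration=False: every pass reuses the first order.
            order = first_order
        else:
            # reshuffle_each_iteration=True: each later pass gets a fresh order.
            order = list(elements)
            rng.shuffle(order)
        output.extend(order)
    return output


flist = ["trimg1", "trimg2", "trimg3", "trimg4"]
fixed = shuffle_then_repeat(flist, 4, reshuffle_each_iteration=False)
fresh = shuffle_then_repeat(flist, 4, reshuffle_each_iteration=True)

# Split the flat streams back into passes of len(flist) elements.
fixed_passes = [fixed[i:i + len(flist)] for i in range(0, len(fixed), len(flist))]
fresh_passes = [fresh[i:i + len(flist)] for i in range(0, len(fresh), len(flist))]
print("reshuffle_each_iteration=False:", fixed_passes)  # four identical passes
print("reshuffle_each_iteration=True: ", fresh_passes)  # passes can differ
```

The seed argument stands in for whatever seed the real shuffle uses; changing it changes the order but not how repetitions behave.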

I am not sure how to reformulate your code to deal with the variable number of samples, but here is your code modified using the repeat() function to demonstrate my claim:

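import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)

flist = ["trimg1", "trimg2", "trimg3", "trimg4"]

filenames = tf.constant(flist)

train_x_dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Shuffle once per iterator initialization; with reshuffle_each_iteration=False,
# every repetition produced by repeat() reuses that same order.
train_x_dataset = train_x_dataset.shuffle(buffer_size=10, reshuffle_each_iteration=False)
train_x_dataset = train_x_dataset.repeat(4)  # four passes over the data from one iterator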

it_train_x = train_x_dataset.make_initializable_iterator()

next_sample = it_train_x.get_next()

with tf.Session() as sess:
    sess.run(it_train_x.initializer)  # initialize once; repeat() supplies the passes
    while True:
        try:
            s = sess.run(next_sample)
            print("Sample: ", s)
        except tf.errors.OutOfRangeError:
            break

Output:

Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3

Whereas if I set reshuffle_each_iteration=True, I get:

Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg3
Sample:  trimg2
Sample:  trimg1
Sample:  trimg4
Sample:  trimg3
Sample:  trimg1
Sample:  trimg2
Sample:  trimg4
Sample:  trimg4
Sample:  trimg1
Sample:  trimg2
Sample:  trimg3

Hope this helps!

Edit: Further evidence for my claim comes from two test functions in the TensorFlow codebase. They use a one-shot iterator, so it is initialized only once. A batch size of 10 is used on data of size 10, so each call to iterator.get_next() covers the entire source data, and the tests check whether each successive call returns the same (shuffled) array.
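The property those tests check can be sketched in pure Python. This is a hypothetical re-creation of the idea, not the actual TensorFlow test code: it assumes a shuffled 10-element dataset is repeated and batched by 10, so each batch is one full pass, and then checks whether every batch matches the first. The two streams are hand-built stand-ins for what the iterator would yield with reshuffle_each_iteration=False and True; the helpers batches and all_batches_equal are illustrative names.

```python
import random


def batches(stream, batch_size):
    """Split a flat element stream into consecutive batches, like .batch(batch_size)."""
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]


def all_batches_equal(stream, batch_size):
    """Return True if every batch is identical to the first one."""
    chunks = batches(stream, batch_size)
    return all(chunk == chunks[0] for chunk in chunks)


rng = random.Random(7)

# Stand-in for reshuffle_each_iteration=False: one fixed permutation, three passes.
permutation = list(range(10))
rng.shuffle(permutation)
stream_no_reshuffle = permutation * 3

# Stand-in for reshuffle_each_iteration=True: a new permutation for each pass.
stream_reshuffle = []
for _ in range(3):
    p = list(range(10))
    rng.shuffle(p)
    stream_reshuffle.extend(p)

print(all_batches_equal(stream_no_reshuffle, 10))  # True
print(all_batches_equal(stream_reshuffle, 10))     # passes drawn independently
```

Because the batch size equals the dataset size, comparing batches is the same as comparing whole passes, which is why the tests can use it to check the reshuffle behavior.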

The discussion on this issue might shed more light on the intended use and expected behavior of the different iterators, and perhaps help you find a solution to your specific problem.

KRish answered Mar 15 '23 13:03