TensorFlow Dataset Shuffle Each Epoch

In the manual for the Dataset class in TensorFlow, it shows how to shuffle the data and how to batch it. However, it's not apparent how one can shuffle the data each epoch. I've tried the code below, but the data is given in exactly the same order in the second epoch as in the first. Does anybody know how to shuffle between epochs using a Dataset?

import tensorflow as tf

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))
asked May 23 '17 by Nathan


People also ask

What is shuffle size in TensorFlow?

For perfect shuffling, set the buffer size equal to the full size of the dataset. For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer.
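
As a minimal sketch (written in the same TF 1.x session style as the rest of this page), setting buffer_size to the full dataset size gives a uniform shuffle:

import tensorflow as tf

dataset = tf.data.Dataset.range(12)
# buffer_size equals the dataset size, so the whole dataset is shuffled at once
dataset = dataset.shuffle(buffer_size=12)
next_element = dataset.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(12):
    print(sess.run(next_element))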

How do I iterate over a TensorFlow dataset?

To iterate over the dataset several times, use .repeat(). We can enumerate each batch by using either Python's enumerator or a built-in method. The former produces a tensor, which is recommended.
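
As a rough sketch (this one assumes a TensorFlow 2.x environment with eager execution, unlike the TF 1.x session code in the question and answers):

import tensorflow as tf

# two passes over the data, two batches per pass
dataset = tf.data.Dataset.range(6).batch(3).repeat(2)
for step, batch in enumerate(dataset):
    print(step, batch.numpy())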

How many batches are in TensorFlow dataset?

By default, the batch size (batch_size) is 32.


2 Answers

My environment: Python 3.6, TensorFlow 1.4.

TensorFlow now provides Dataset under tf.data (rather than tf.contrib.data).

Be careful about where you place data.shuffle. In your code, all the epochs of data have already been repeated and batched before the shuffle is applied. Here are two working examples of shuffling a dataset.

shuffle all elements

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]
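
Because shuffle() comes before repeat() here, the order is re-drawn on every pass over the data (shuffle()'s reshuffle_each_iteration argument defaults to True in current releases), which is why the two epochs above come out in different orders.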

shuffle between batches, not within a batch

# shuffle between batches, not within a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]
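
Note that with shuffle() applied after batch() and repeat(), only whole batches are reordered and batches from different passes get mixed together in the buffer, which is why the same batch (for example [6 7 8]) can show up twice under "epoch 1" above.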
answered Oct 20 '22 by Tom


It appears that you are using the same next_batch for both cases. So, depending on what you really want, you may need to recreate next_batch before your second set of sess.run calls, as shown below; otherwise data = data.shuffle(12) has no effect on the next_batch you created earlier in the code.

import tensorflow as tf

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)

"""See how I recreate next_batch after the data has been shuffled"""
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))

Please let me know if this helps. Thanks.

answered Oct 20 '22 by emmanuelsa