In the manual on the Dataset class in TensorFlow, it shows how to shuffle the data and how to batch it. However, it's not apparent how one can shuffle the data each epoch. I've tried the code below, but the data is given in exactly the same order in the second epoch as in the first. Does anybody know how to shuffle between epochs using a Dataset?
n_epochs = 2
batch_size = 3
data = tf.contrib.data.Dataset.range(12)
data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()
sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))
print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))
For perfect shuffling, set the buffer size equal to the full size of the dataset. For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer.
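As a small sketch of that difference (the dataset and buffer sizes here are illustrative, not taken from the question):

# Sketch: partial vs. perfect shuffling (sizes are illustrative)
import tensorflow as tf

dataset = tf.data.Dataset.range(10000)

# Buffer smaller than the dataset: the first element drawn is picked
# from only the first 1,000 elements, so the shuffle is only approximate.
partial = dataset.shuffle(buffer_size=1000)

# Buffer as large as the dataset: all elements fit in the buffer,
# giving a uniform ("perfect") shuffle.
perfect = dataset.shuffle(buffer_size=10000)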
To iterate over the dataset several times, use .repeat(). Each batch can be enumerated with either Python's enumerate or the dataset's built-in enumerate method; the built-in method yields the step index as a tensor, which is generally preferable inside a tf.data pipeline.
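A minimal sketch of both options, assuming TF 2.x eager execution rather than the TF 1.4 session-based code used elsewhere on this page:

# Sketch: enumerating batches (assumes TF 2.x eager execution)
import tensorflow as tf

dataset = tf.data.Dataset.range(12).batch(3).repeat(2)

# Built-in method: Dataset.enumerate() yields (index, batch) pairs,
# where the index itself is a tensor.
for step, batch in dataset.enumerate():
    print(step.numpy(), batch.numpy())

# Python's enumerate also works in eager mode; the counter is a plain int.
for step, batch in enumerate(dataset):
    print(step, batch.numpy())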
Note that the default batch size of 32 applies to Keras's model.fit; tf.data's Dataset.batch() has no default and requires an explicit batch_size.
My environment: Python 3.6, TensorFlow 1.4.
TensorFlow has added Dataset into tf.data. You should be careful about the position of data.shuffle. In your code, the epochs of data have been put into the dataset's buffer before your shuffle. Here are two usable examples of shuffling a dataset.
shuffle all elements
# shuffle all elements
import tensorflow as tf
n_epochs = 2
batch_size = 3
buffer_size = 5
dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)  # shuffle individual elements...
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)                  # ...and every repetition re-runs the shuffle
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))
OUTPUT:
epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10 2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11 9 10]
shuffle between batches, not within a batch
# shuffle between batches, not within a batch
import tensorflow as tf
n_epochs = 2
batch_size = 3
buffer_size = 5
dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)  # shuffle whole batches; elements inside a batch stay together
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))
OUTPUT:
epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]
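If you are on TF 2.x, where tf.Session and make_one_shot_iterator no longer exist, roughly the same "shuffle all elements, reshuffled every epoch" behaviour can be sketched as below; this eager-mode rewrite is an assumption, not part of the original TF 1.4 answer.

# Sketch: per-epoch shuffling in TF 2.x eager mode (not the TF 1.4 code above)
import tensorflow as tf

n_epochs = 2
batch_size = 3

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=12)  # buffer covers the whole dataset
dataset = dataset.batch(batch_size)

for epoch in range(n_epochs):
    print("epoch", epoch + 1)
    # Each fresh pass over the dataset is reshuffled by default.
    for batch in dataset:
        print(batch.numpy())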
It appears to me that you are using the same next_batch for both cases. So, depending on what you really want, you may need to recreate next_batch before your second call to sess.run, as shown below; otherwise data = data.shuffle(12) has no effect on the next_batch you created earlier in the code, because shuffle returns a new Dataset rather than modifying the one behind the existing iterator.
n_epochs = 2
batch_size = 3
data = tf.contrib.data.Dataset.range(12)
data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()
sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))
print("new epoch")
data = data.shuffle(12)
# See how next_batch is recreated after the data has been shuffled
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))
Please let me know if this helps. Thanks.