
Tensorflow dataset questions about .shuffle, .batch and .repeat

I have a question about the use of batch, repeat and shuffle with tf.data.Dataset.

It is not clear to me exactly how repeat and shuffle are used. I understand that .batch dictates how many training examples go into each step of stochastic gradient descent, but the uses of .repeat and .shuffle are still not clear to me.

First Question

After reviewing here and here, my understanding is that .repeat is used to iterate over the dataset again once it is exhausted, i.e. at the point where a tf.errors.OutOfRangeError would otherwise be thrown. Does that mean that in my code I no longer have to implement:

try:
    while True:
        _ = sess.run(self.optimizer)

except tf.errors.OutOfRangeError:
    pass

because .repeat will automatically repeat the dataset once it is exhausted? When does it stop? Or will it never stop, so that you just have to exit the while True loop once a certain number of batches (say 1000) have passed?

Second Question

Secondly, the use of .shuffle makes no sense to me. Does .shuffle.batch() imply that if I have, say, 100,000 samples, I put 1000 of them randomly in a buffer with .shuffle and then batch, say, 100 of them with .batch()? From my understanding, the next batch will use 999 of those samples and place 1 new one in the buffer. So if my samples have no order to them, should .shuffle be avoided altogether? And if .batch is used, would it still batch 100 from those 999+1 in the buffer?

Third Question

And lastly, if I am using a separate tf.data.Dataset object for testing, what order of .shuffle.batch() should I consider? Right now I use:

sess.run(self.test_init)
try:
    while True:
        accuracy_batch = sess.run(self.accuracy)

except tf.errors.OutOfRangeError:
    pass

With:

test_data = self.test_dataset.shuffle(self.batch_size).batch(self.batch_size)

I have over 110,000 training examples at my disposal, so self.batch_size will set the number of samples I want to use to test my accuracy. So, if I wanted to just test on the whole test dataset I wouldn't use .batch? But since I have it iterating over the whole dataset with while True, does it make no difference? With the use of .shuffle I noticed my accuracies changed, but without it they were very similar. This makes me think .shuffle is randomizing the batch and may be reusing training examples?

Jamie Dimon asked Oct 16 '22 13:10

1 Answer

First Question:

That's correct - if the dataset repeats indefinitely (repeat() with no argument) it is never exhausted, so you no longer need to catch the OutOfRangeError.

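repeat() takes an optional argument for the number of times it should repeat. This means repeat(10) will iterate over the entire dataset 10 times, after which the dataset is exhausted. If you omit the argument, it will repeat indefinitely, so you have to stop the loop yourself, for example after a fixed number of batches.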

Second Question

shuffle() (if used) should be called before batch() - we want to shuffle records, not batches.

The buffer is first filled by adding your records in order; then, once it is full, a random record is selected and emitted, and a new record is read from the original source to take its place.

If you have something like

ds.shuffle(1000).batch(100)

then in order to return a single batch, this last step is repeated 100 times (maintaining the buffer at 1000). Batching is a separate operation.
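
The buffer mechanics can be sketched in plain Python. This is a simplified model of the behaviour described above, not TensorFlow's implementation; the shuffle and batch helpers and the fixed seed are assumptions made for the illustration.

```python
import itertools
import random

def shuffle(records, buffer_size, seed=None):
    """Model of shuffle(buffer_size): emit a random buffered record, then refill."""
    rng = random.Random(seed)
    source = iter(records)
    # Fill the buffer with the first buffer_size records, in order.
    buffer = list(itertools.islice(source, buffer_size))
    for record in source:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]        # emit one random record from the buffer
        buffer[idx] = record     # replace it with the next record from the source
    # Source exhausted: emit whatever is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(stream, batch_size):
    """Model of batch(): group consecutive emitted records."""
    it = iter(stream)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk

records = range(100_000)
batches = list(batch(shuffle(records, buffer_size=1000, seed=0), 100))

first = batches[0]
print(len(first))   # 100 records were drawn from the buffer, one at a time
print(max(first))   # at most 1098: only the first 1099 records could have been seen
```

So one batch takes 100 records out of the buffer and pulls 100 new ones in, rather than 1. Every record is still emitted exactly once per pass, which is why shuffling does not reuse examples within a pass. The model also shows why a small buffer only shuffles locally: a record cannot be emitted much earlier than its position in the source.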

Third Question

Generally we don't shuffle a test set at all - only the training set (We evaluate using the entire test set anyway, right? So why shuffle?).

"So, if I wanted to just test on the whole test dataset I wouldn't use .batch"

Hmm - not so (at least not always). You would certainly need to use batch if your whole test dataset didn't fit into memory - a common occurrence. You would want to test on the whole dataset but run the numbers in manageable chunks!
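
As a rough illustration of evaluating the whole test set in chunks, the sketch below accumulates the number of correct predictions across batches and divides once at the end. The results list is made-up data and the batches helper is a hypothetical name; in a real run the per-batch counts would come from the model.

```python
def batches(items, batch_size):
    """Yield consecutive slices of items; the last slice may be shorter."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Hypothetical per-example outcomes: 1 = correct prediction, 0 = wrong.
results = [1, 1, 0, 1, 1, 1, 0, 1, 1, 0]   # 7 of 10 correct

correct = 0
seen = 0
for chunk in batches(results, 4):
    correct += sum(chunk)   # correct predictions in this batch
    seen += len(chunk)      # examples in this batch
accuracy = correct / seen
print(accuracy)             # 0.7, the same as evaluating all 10 examples at once

# Averaging per-batch accuracies without weighting by batch size can differ
# when the last batch is smaller than the others.
per_batch = [sum(c) / len(c) for c in batches(results, 4)]
naive = sum(per_batch) / len(per_batch)
print(naive)
```

Accumulated this way, the overall accuracy does not depend on the batch size or on the order of the examples, even though individual per-batch accuracies do.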

Stewart_R answered Nov 04 '22 20:11
