
How does the shuffle = 'batch' argument of .fit() work in the background?

When I train a model with the .fit() method, the shuffle argument defaults to True.

Let's say that my dataset has 100 samples and that the batch size is 10. When I set shuffle = True, keras first randomly reorders the samples (so the 100 samples now have a different order) and then creates the batches from the new order: batch 1: 1-10, batch 2: 11-20, etc.

If I set shuffle = 'batch', how is it supposed to work in the background? Using the same example of a 100-sample dataset with batch size = 10, my guess is that keras first allocates the samples to batches in the original dataset order (batch 1: samples 1-10, batch 2: samples 11-20, batch 3: samples 21-30, and so on) and then shuffles the order of the batches. The model would then be trained on the batches in random order, for example: 3 (contains samples 21-30), 4 (contains samples 31-40), 7 (contains samples 61-70), 1 (contains samples 1-10), ... (I made up the order of the batches).
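To make the two behaviours concrete, here is a minimal NumPy sketch of what I am describing. It is not Keras code: the function names `sample_shuffle` and `batch_level_shuffle` are made up for illustration, indices are 0-based rather than 1-based, and it assumes the dataset size is divisible by the batch size.

```python
import numpy as np

def sample_shuffle(n_samples, batch_size, rng):
    """shuffle=True as described above: permute samples, then cut into batches."""
    order = rng.permutation(n_samples)
    return order.reshape(-1, batch_size)  # one row per batch

def batch_level_shuffle(n_samples, batch_size, rng):
    """My guess for shuffle='batch': cut into batches first, then permute the batches."""
    batches = np.arange(n_samples).reshape(-1, batch_size)
    return batches[rng.permutation(len(batches))]  # reorder rows only

rng = np.random.default_rng(0)
by_sample = sample_shuffle(100, 10, rng)      # rows mix samples from anywhere
by_batch = batch_level_shuffle(100, 10, rng)  # rows stay contiguous, e.g. 20..29
```

In the second case every row still holds ten consecutive indices; only the order of the rows changes.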

Is my thinking right or am I missing something?

Thanks!

mrt asked Nov 08 '22 18:11



1 Answer

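Looking at the implementation at this link (line 349 of training.py), the answer seems to be yes: the indices are split into batches in their original order, and then the batches themselves are shuffled.

You can check this with the following code: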

import numpy as np
def batch_shuffle(index_array, batch_size):
    """Shuffles an array in a batch-wise fashion.
    Useful for shuffling HDF5 arrays
    (where one cannot access arbitrary indices).
    # Arguments
        index_array: array of indices to be shuffled.
        batch_size: integer.
    # Returns
        The `index_array` array, shuffled in a batch-wise fashion.
    """
    batch_count = len(index_array) // batch_size  # number of full batches
    # to reshape we need to be cleanly divisible by batch size
    # we stash extra items and reappend them after shuffling
    last_batch = index_array[batch_count * batch_size:]
    index_array = index_array[:batch_count * batch_size]
    index_array = index_array.reshape((batch_count, batch_size))
    np.random.shuffle(index_array)
    index_array = index_array.flatten()
    return np.append(index_array, last_batch)


x = np.arange(100)
x_s = batch_shuffle(x, 10)
# each row is one batch: contiguous inside, rows in random order
print(x_s.reshape(10, 10))
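If you prefer an automatic check over reading the printed array, the sketch below may help. It repeats a condensed copy of `batch_shuffle` so that it runs on its own; the helper `batches_are_intact` is a hypothetical name introduced here, not part of Keras. It also shows what happens when the dataset size is not a multiple of the batch size: judging from the function above, the leftover samples are not shuffled and stay at the end.

```python
import numpy as np

def batch_shuffle(index_array, batch_size):
    # Condensed copy of the function above, so this snippet is self-contained.
    batch_count = len(index_array) // batch_size
    last_batch = index_array[batch_count * batch_size:]
    full = index_array[:batch_count * batch_size].reshape((batch_count, batch_size))
    np.random.shuffle(full)  # permutes rows (whole batches) in place
    return np.append(full.flatten(), last_batch)

def batches_are_intact(shuffled, batch_size):
    """True if every full batch is a run of consecutive indices
    that starts at a multiple of batch_size."""
    n_full = len(shuffled) // batch_size
    rows = shuffled[:n_full * batch_size].reshape(n_full, batch_size)
    starts_ok = np.all(rows[:, 0] % batch_size == 0)
    runs_ok = np.all(np.diff(rows, axis=1) == 1)
    return bool(starts_ok and runs_ok)

x_s = batch_shuffle(np.arange(100), 10)
print(batches_are_intact(x_s, 10))  # True

# 105 samples: ten full batches are shuffled, the 5 leftovers stay last
x_r = batch_shuffle(np.arange(105), 10)
print(x_r[-5:])  # [100 101 102 103 104]
```

A sample-level shuffle (for example `np.random.permutation(100)`) would almost certainly fail the same check, which is the difference between shuffle = True and shuffle = 'batch'.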
user2614596 answered Nov 22 '22 08:11
