 

What does "shuffle" do in fit_generator in keras?

I manually built a data generator that yields a tuple of [input, target] on each call. I set my generator to shuffle the training samples every epoch. Then I use fit_generator to call my generator, but I am confused by the "shuffle" argument of this function:

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

From Keras API:

shuffle: Whether to shuffle the order of the batches at the beginning of each epoch. Only used with instances of Sequence (keras.utils.Sequence)

I thought shuffling should be the job of the generator. How can fit_generator shuffle the order of the batches when my custom generator decides which batch is output in each iteration?

Tu Bui asked Feb 28 '18 10:02


1 Answer

As the documentation you quoted says, the shuffle argument is only relevant for a generator that implements keras.utils.Sequence.

If you are using a "simple" generator (such as keras.preprocessing.image.ImageDataGenerator, or your own custom non-Sequence generator), then that generator is a Python generator that returns a single batch each time it is advanced (using yield; you can learn more about it in this question). Therefore, only the generator itself controls which batch is returned, and the shuffle argument of fit_generator has no effect.
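A minimal sketch of such a plain generator, assuming the data fits in in-memory lists (the names `batch_generator`, `inputs`, `targets`, `batch_size` and `seed` are hypothetical). Because fit_generator can only call `next()` on it, any reshuffling has to happen inside the generator itself:

```python
import random


def batch_generator(inputs, targets, batch_size, seed=None):
    """Hypothetical non-Sequence generator: it alone decides batch contents/order.

    fit_generator only calls next() on it, so the shuffle argument is ignored;
    reshuffling is done here, once per full pass over the data.
    """
    rng = random.Random(seed)
    indices = list(range(len(inputs)))
    while True:  # Keras expects the generator to loop indefinitely
        rng.shuffle(indices)  # new sample order at the start of every pass
        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            # In a real Keras pipeline you would usually yield NumPy arrays here.
            yield ([inputs[i] for i in batch_idx],
                   [targets[i] for i in batch_idx])


gen = batch_generator(list(range(10)), [x * 2 for x in range(10)], batch_size=4, seed=0)
xs, ys = next(gen)
```

One "pass" of this generator lines up with one Keras epoch only if you set steps_per_epoch to the number of batches per pass (here ceil(10 / 4) = 3); otherwise the reshuffle happens mid-epoch.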

keras.utils.Sequence was introduced to support multiprocessing:

Sequence are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch which is not the case with generators.

To that end, you need to implement a method that returns a batch given its batch index (which allows multiple workers to be synchronized): __getitem__(self, idx), together with __len__(self), which returns the number of batches per epoch. If you enable the shuffle argument, fit_generator will invoke __getitem__ with the batch indexes in a random order.
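The sketch below illustrates this. `IndexedBatches` is a hypothetical Sequence whose batch number `idx` always contains the same samples, and `run_epoch` is a hypothetical, simplified single-process stand-in for the training loop (roughly what Keras's internal enqueuer does: pick the batch indices and, with shuffle=True, visit them in random order). It is not the actual Keras code. The import falls back to `object` so the sketch also runs without Keras installed:

```python
import math
import random

try:  # use the real base class when Keras is installed
    from keras.utils import Sequence
except ImportError:  # fallback so this sketch also runs without Keras
    Sequence = object


class IndexedBatches(Sequence):
    """Hypothetical Sequence: batch number `idx` always holds the same samples."""

    def __init__(self, inputs, targets, batch_size):
        super().__init__()
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch (the last one may be smaller).
        return math.ceil(len(self.inputs) / self.batch_size)

    def __getitem__(self, idx):
        # Contiguous slice: the mapping idx -> samples never changes.
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.inputs[start:end], self.targets[start:end]


def run_epoch(sequence, shuffle, rng):
    """Simplified stand-in for the training loop: it chooses the batch
    indices and, if shuffle is True, visits them in random order."""
    order = list(range(len(sequence)))
    if shuffle:
        rng.shuffle(order)
    return [sequence[i] for i in order]


seq = IndexedBatches(list(range(10)), [x * 2 for x in range(10)], batch_size=4)
rng = random.Random(0)
for epoch in range(2):
    # The batch order can differ between epochs; the batch contents do not.
    print(epoch, [xs for xs, _ in run_epoch(seq, shuffle=True, rng=rng)])
```

This is why fit_generator can shuffle a Sequence but not a plain generator: with a Sequence, the caller chooses the index passed to __getitem__.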

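However, you may also set it to False and shuffle the samples yourself by implementing the on_epoch_end method. The difference matters: as long as __getitem__ maps each index to a fixed set of samples, shuffle=True only permutes the order of whole batches, so the same samples are grouped together in every epoch. Reshuffling the sample indices in on_epoch_end also changes which samples end up in each batch.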
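A minimal sketch of the on_epoch_end approach, under the same assumptions as before (in-memory lists; `ReshuffledBatches` and its parameters are hypothetical names). Keras calls on_epoch_end after each epoch; the constructor also calls it once so the first epoch is shuffled too:

```python
import math
import random

try:  # use the real base class when Keras is installed
    from keras.utils import Sequence
except ImportError:  # fallback so this sketch also runs without Keras
    Sequence = object


class ReshuffledBatches(Sequence):
    """Hypothetical Sequence that reshuffles samples itself in on_epoch_end."""

    def __init__(self, inputs, targets, batch_size, seed=None):
        super().__init__()
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        self.indices = list(range(len(inputs)))
        self.on_epoch_end()  # shuffle once before the first epoch

    def __len__(self):
        return math.ceil(len(self.inputs) / self.batch_size)

    def __getitem__(self, idx):
        # Batch idx takes its samples from the current (shuffled) index order.
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return ([self.inputs[i] for i in batch_idx],
                [self.targets[i] for i in batch_idx])

    def on_epoch_end(self):
        # New sample order, so every batch gets different samples next epoch.
        self.rng.shuffle(self.indices)


seq = ReshuffledBatches(list(range(10)), [x * 2 for x in range(10)], batch_size=4, seed=1)
first_epoch = [seq[i] for i in range(len(seq))]
seq.on_epoch_end()  # normally called by Keras between epochs
second_epoch = [seq[i] for i in range(len(seq))]
```

With this class you would typically pass shuffle=False to fit_generator, since the Sequence already randomizes which samples form each batch; leaving shuffle=True as well is harmless and additionally randomizes the batch order.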

Mark Loyman answered Sep 22 '22 11:09