 

on_epoch_end() not called in keras fit_generator()

I followed this tutorial to generate data on the fly with the Keras fit_generator() method to train my neural network model.

I created a generator by subclassing keras.utils.Sequence. The call to fit_generator() is:

history = model.fit_generator(generator=EVDSSequence(images_train, TRAIN_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              steps_per_epoch=None, epochs=EPOCHS,
                              validation_data=EVDSSequence(images_valid, VALID_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              validation_steps=None,
                              callbacks=callbacksList, verbose=1,
                              workers=0, max_queue_size=1, use_multiprocessing=False)

steps_per_epoch is None, so Keras takes the number of steps per epoch from the Sequence's __len__() method.

As the tutorial says:

Here, the method on_epoch_end is triggered once at the very beginning as well as at the end of each epoch. If the shuffle parameter is set to True, we will get a new order of exploration at each pass (or just keep a linear exploration scheme otherwise).

My problem is that the on_epoch_end() method is called only at the very beginning, never at the end of each epoch, so the batch order is the same in every epoch.

I tried using np.ceil instead of np.floor in the __len__() method, without success.
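The rounding in __len__() only determines how many batches Keras requests per epoch, so on its own it would not be expected to change whether on_epoch_end() is called. A small worked example, using a hypothetical helper batches_per_epoch and the standard-library math module in place of NumPy:

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last):
    """Hypothetical helper: the value a Sequence's __len__() would return."""
    if drop_last:
        # np.floor variant: the trailing partial batch is dropped
        return math.floor(num_samples / batch_size)
    # np.ceil variant: the trailing partial batch is kept
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(10, 4, drop_last=True))   # 2 (samples 9 and 10 are never served)
print(batches_per_epoch(10, 4, drop_last=False))  # 3 (last batch holds 2 samples)
```

Either way, the batch order inside the epoch is unaffected; only the number of batches differs.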

Do you know why on_epoch_end() is not called at the end of each epoch? Could you suggest a workaround to shuffle the order of my batches at the end (or at the beginning) of each epoch?

Many thanks!

rainbow asked Jan 08 '20 12:01



2 Answers

I encountered the same problem. I have no idea why it happens, but there is a workaround: call on_epoch_end() within __len__(), since __len__() is called every epoch.
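A minimal sketch of that workaround, assuming tf.keras and in-memory NumPy arrays. The class name ShuffledSequence and its parameters (x, y, batch_size, shuffle, seed) are hypothetical, and the fallback to object exists only so the snippet can run without TensorFlow installed:

```python
import numpy as np

try:
    from tensorflow.keras.utils import Sequence  # real base class when TensorFlow is installed
except ImportError:
    Sequence = object  # fallback so this sketch runs without TensorFlow


class ShuffledSequence(Sequence):
    """Hypothetical Sequence that reshuffles its batch order from inside __len__()."""

    def __init__(self, x, y, batch_size, shuffle=True, seed=None):
        super().__init__()
        self.x, self.y = np.asarray(x), np.asarray(y)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indexes = np.arange(len(self.x))
        self.on_epoch_end()  # initial shuffle, as in the tutorial

    def __len__(self):
        # Workaround: reshuffle every time Keras asks for the number of batches.
        # If __len__() is called more than once per epoch, the order changes each time.
        self.on_epoch_end()
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Select this batch's sample indexes from the current (shuffled) order
        batch = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[batch], self.y[batch]

    def on_epoch_end(self):
        if self.shuffle:
            self.rng.shuffle(self.indexes)  # in-place permutation of the index array
```

One caveat of this design: it relies on when Keras happens to call __len__(). If a given Keras version calls it in the middle of an epoch, the remaining batches would be drawn from a new order, so some samples could be seen twice and others skipped in that epoch.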

zz400 answered Sep 22 '22 05:09



This might be related to the issue "Keras model.fit not calling Sequence.on_epoch_end()" #35911.

A quick fix is to use a LambdaCallback. (Note that I use fit, which should be sufficient, since fit_generator is deprecated.)

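from tensorflow.keras.callbacks import LambdaCallback

# LambdaCallback calls on_epoch_end(epoch, logs), but Sequence.on_epoch_end() takes no arguments
model.fit(generator, callbacks=[LambdaCallback(on_epoch_end=lambda epoch, logs: generator.on_epoch_end())])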
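A dedicated Callback subclass is an alternative to the lambda and makes the argument adaptation explicit. This is a sketch assuming tf.keras; the class name ShuffleOnEpochEnd is hypothetical, and the fallback to object is there only so the snippet runs without TensorFlow:

```python
try:
    from tensorflow.keras.callbacks import Callback
except ImportError:
    Callback = object  # fallback so this sketch runs without TensorFlow


class ShuffleOnEpochEnd(Callback):
    """Hypothetical callback that forwards Keras's epoch-end hook to a Sequence."""

    def __init__(self, sequence):
        super().__init__()
        self.sequence = sequence

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes (epoch, logs); the Sequence's own on_epoch_end() takes no arguments.
        self.sequence.on_epoch_end()
```

It would be used as model.fit(generator, callbacks=[ShuffleOnEpochEnd(generator)]). If the Keras version in use does call Sequence.on_epoch_end() itself, the order would simply be shuffled twice per epoch, which is harmless for a random shuffle.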

Hope it helps!

Jens C. Thuren Lindahl answered Sep 25 '22 05:09
