
How to continue training model using ModelCheckpoint of Keras

I'm new to Keras and have a question about its training procedure.

Because of a time limit on my server (each job must finish in under 24 hours), I have to train my model in several 10-epoch periods.

In the first training period, the best model over the 10 epochs is saved using Keras's ModelCheckpoint.

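from keras.callbacks import ModelCheckpoint

conf = dict()
conf['nb_epoch'] = 10
callbacks = [
    # Save the full model whenever val_loss improves on the best seen so far
    ModelCheckpoint(filepath='/1st_{epoch:d}_{val_loss:.5f}.hdf5',
                    monitor='val_loss', save_best_only=True,
                    save_weights_only=False, verbose=0)
]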

Assume the best model is '1st_10_1.00000.hdf5'. Next, I continue training for another 10 epochs and save the best model as follows.

# Resume from the best checkpoint of the first period
model.load_weights('1st_10_1.00000.hdf5')
model.compile(...)
callbacks = [
    ModelCheckpoint(filepath='/2nd_{epoch:d}_{val_loss:.5f}.hdf5',
                    monitor='val_loss', save_best_only=True,
                    save_weights_only=False, verbose=0)
]

But I have a problem. The first epoch of the second training period gives a val_loss of 1.20000, and the script saves a model '2nd_1_1.20000.hdf5'. Obviously, this val_loss is greater than the best val_loss of the first training period (1.00000). The following epochs of the second period also seem to be trained based on the model '2nd_1_1.20000.hdf5', not '1st_10_1.00000.hdf5'.

'2nd_1_1.20000.hdf5'
'2nd_2_1.15000.hdf5'
'2nd_3_1.10000.hdf5'
'2nd_4_1.05000.hdf5'
...

I think it is a waste not to use the better result of the first training period. Can anyone point me to a way to fix this, or a way to tell the program to use the best model of the previous training period? Many thanks in advance!

Nghia Duong asked Oct 29 '22 08:10


2 Answers

Interesting case; this could be a great improvement. I don't think the API currently supports this, apart from writing your own callback.

I don't think it would be that hard. You could base it on the original ModelCheckpoint callback class and change just this line: https://github.com/fchollet/keras/blob/master/keras/callbacks.py#L390

It stores the current best value of the quantity being logged. It is initialised in an if statement as either -inf or inf, depending on whether the monitored quantity should be maximised or minimised.
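The effect of that initial value can be illustrated without Keras. The sketch below is a plain-Python stand-in for the save_best_only comparison when monitoring val_loss, not the library's own code. The function name epochs_saved is hypothetical, and the loss values other than 1.2 are made up for illustration.

```python
def epochs_saved(val_losses, initial_best=float("inf")):
    """Return the 1-based epochs at which a checkpoint would be written.

    Mimics save_best_only for a quantity that should be minimised:
    a checkpoint is written only when the loss beats the best so far.
    """
    best = initial_best
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            saved.append(epoch)
    return saved


# Hypothetical losses for a second training period.
second_period = [1.2, 1.3, 0.9]

# A fresh tracker starts at inf, so epoch 1 (val_loss 1.2) is saved.
print(epochs_saved(second_period))                    # [1, 3]

# Seeded with the previous best of 1.0, only the real improvement is saved.
print(epochs_saved(second_period, initial_best=1.0))  # [3]
```

With the default starting value, epoch 1 is saved even though 1.2 is worse than the 1.0 reached in the previous period; seeding the tracker with 1.0 suppresses that save.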

In your case you would have to read the filename of the saved file, do some string manipulation to extract the value, and use that as the initial value instead. I would suggest adding it as a separate statement, or as an else if, to avoid messing with the core code too much.
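A minimal sketch of the string manipulation, assuming checkpoint names follow the question's pattern and end in _{val_loss}.hdf5 with a plain decimal number. The helper name best_val_loss_from_filename is hypothetical. The Keras part is shown only in comments and assumes the callback still keeps its running best in an attribute called best, as in the line linked above.

```python
import re


def best_val_loss_from_filename(path):
    """Extract the val_loss embedded in a name like '1st_10_1.00000.hdf5'."""
    # The last underscore-separated number before the extension is the loss.
    match = re.search(r"_(\d+(?:\.\d+)?)\.hdf5$", path)
    if match is None:
        raise ValueError("no val_loss found in %r" % path)
    return float(match.group(1))


previous_best = best_val_loss_from_filename("1st_10_1.00000.hdf5")
print(previous_best)

# With Keras available, the value could seed the second period's callback
# instead of editing the library source (untested assumption):
#
#   checkpoint = ModelCheckpoint(filepath='/2nd_{epoch:d}_{val_loss:.5f}.hdf5',
#                                monitor='val_loss', save_best_only=True)
#   checkpoint.best = previous_best
```

Setting the attribute on the instance avoids touching the core code at all. Newer Keras releases also document an initial_value_threshold argument on ModelCheckpoint for the same purpose; check whether your version has it.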

Hope that helps.

J.Down answered Nov 15 '22 10:11


I ran into the identical problem, and did not see your question until I had asked my own. Based on the feedback I got, I wrote a simple callback that saves and restores the best training values (such as val_loss). You can find it here: How to preserve metric values over training sessions in Keras?
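The general idea of saving and restoring the best value can be sketched as follows. This is not the callback from the linked question, only a rough illustration that keeps the value in a small JSON file next to the checkpoints. The helper names load_best and save_best and the file name best.json are hypothetical.

```python
import json
import os


def load_best(path, default=float("inf")):
    """Return the best value stored at path, or default if no file exists."""
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)["best"]


def save_best(path, value):
    """Write the best value seen so far to path."""
    with open(path, "w") as f:
        json.dump({"best": value}, f)


# In a training script this could wrap model.fit (untested assumption,
# relying on ModelCheckpoint keeping its running best in `best`):
#
#   checkpoint.best = load_best("best.json")   # before training
#   model.fit(..., callbacks=[checkpoint])
#   save_best("best.json", checkpoint.best)    # after training
```

Compared with parsing the checkpoint filename, a separate file does not depend on the naming pattern, at the cost of one more file to keep in sync with the checkpoints.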

MadOverlord answered Nov 15 '22 09:11