 

save model weights at the end of every N epochs

I'm training a neural network and would like to save the model weights every N epochs so I can later run a prediction phase from those checkpoints. I've drafted the code below, inspired by @grovina's response here. Could you please make suggestions? Thanks in advance.

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

Then add it to the fit call to save weights every 5 epochs:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
Asked Jul 05 '18 by Belkacem Thiziri


1 Answer

You shouldn't need to pass the model to the callback: it already has access to the model through its superclass, which Keras wires up when training starts. So remove the model argument from __init__ and the self.model = model assignment; you can still access the current model via self.model. You are also saving on every batch end, which is not what you want; you probably want on_epoch_end.
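For illustration, a minimal sketch of the custom callback with those changes applied (no model argument, saving in on_epoch_end and using the epoch index Keras passes in, so the manual counter goes away):

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, N):
        super(WeightsSaver, self).__init__()
        self.N = N

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras when the callback is attached to fit()
        if epoch % self.N == 0:
            name = 'weights%08d.h5' % epoch
            self.model.save_weights(name)

model.fit(X_train, Y_train, callbacks=[WeightsSaver(5)])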

But in any case, what you are doing can be done with the built-in ModelCheckpoint callback; you don't need to write a custom one. You can use it as follows:

import keras

# saves weights to e.g. weights00000005.h5 every 5 epochs
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
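As a side note, newer tf.keras versions deprecate the period argument in favor of save_freq, which counts batches rather than epochs. A sketch of an equivalent setup, assuming a known batch size (the batch_size value here is illustrative):

import math
import tensorflow as tf

batch_size = 32  # assumed for illustration
steps_per_epoch = math.ceil(len(X_train) / batch_size)

# save_freq is measured in batches, so 5 epochs = 5 * steps_per_epoch batches
mc = tf.keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                        save_weights_only=True,
                                        save_freq=5 * steps_per_epoch)
model.fit(X_train, Y_train, batch_size=batch_size, callbacks=[mc])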
Answered Sep 25 '22 by umutto