I'm training a NN and would like to save the model weights every N epochs for a prediction phase. I've written this draft code, inspired by @grovina's response here. Could you please make suggestions? Thanks in advance.
from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1
Then add it to the fit call. To save weights every 5 epochs:
model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
You shouldn't need to pass a model to the callback. It already has access to the model through its parent class: Keras attaches the model to the callback when you pass it to fit(). So remove the model argument from __init__(..., model, ...) and the line self.model = model. You can still access the current model via self.model. You are also saving on every batch end, which is not what you want; you probably want on_epoch_end instead.
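Applying those changes, a minimal sketch of the corrected callback could look like this. It assumes, as far as I know for Keras 2.x and 3.x, that Keras calls `set_model` on the callback inside fit(). The `template` parameter is a hypothetical addition of mine, and its `.weights.h5` suffix is chosen because newer Keras versions require that suffix for `save_weights`. No manual counter is needed, because Keras passes the epoch index to on_epoch_end.

```python
from keras.callbacks import Callback


class WeightsSaver(Callback):
    """Save the model's weights every `n` epochs.

    Keras attaches the model (via set_model) when the callback is passed
    to fit(), so no model argument is needed.
    """

    def __init__(self, n, template="weights{epoch:08d}.weights.h5"):
        super().__init__()  # let the base class initialise its attributes
        self.n = n
        # Hypothetical filename pattern; ".weights.h5" is assumed because
        # newer Keras versions reject other suffixes in save_weights().
        self.template = template

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; use the 1-based epoch number
        # for both the schedule and the filename.
        epoch_number = epoch + 1
        if epoch_number % self.n == 0:
            self.model.save_weights(self.template.format(epoch=epoch_number))
```

Usage: `model.fit(X_train, Y_train, epochs=20, callbacks=[WeightsSaver(5)])` would write weights00000005.weights.h5, weights00000010.weights.h5, and so on.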
But in any case, what you are doing can be done with the built-in ModelCheckpoint callback; you don't need to write a custom one. You can use it as follows:
import keras

# period=5 saves every 5 epochs (Keras 2.x API; deprecated in later versions)
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
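Note that ModelCheckpoint fills `{epoch}` with the 1-based epoch number, while on_epoch_end receives a 0-based index. The stdlib-only sketch below reproduces, as far as I understand the Keras 2.x implementation, how `period` decides when to save, so you can see which files to expect for the prediction phase. `checkpoint_files` is a hypothetical helper for illustration, not part of Keras.

```python
def checkpoint_files(num_epochs, period, template="weights{epoch:08d}.h5"):
    """Return the filenames a period-based checkpoint would write.

    Assumes the Keras 2.x ModelCheckpoint logic: count epochs since the
    last save, save and reset the counter once it reaches `period`, and
    name the file after the 1-based epoch number.
    """
    files = []
    epochs_since_last_save = 0
    for epoch in range(num_epochs):  # 0-based index, as passed to on_epoch_end
        epochs_since_last_save += 1
        if epochs_since_last_save >= period:
            epochs_since_last_save = 0
            files.append(template.format(epoch=epoch + 1))
    return files


print(checkpoint_files(20, 5))
# ['weights00000005.h5', 'weights00000010.h5', 'weights00000015.h5', 'weights00000020.h5']
```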