After a lot of research, it seems there is no good way to properly stop and resume training with a TensorFlow 2 / Keras model. This is true whether you use model.fit() or a custom training loop.
There seem to be two supported ways to save a model while training (a minimal sketch of both follows below):

1. Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I've seen, however it has a number of major issues.

2. Save the entire model, optimizer, etc. using model.save() or save_weights_only=False. The optimizer state is saved (good), but other issues remain.
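For illustration, here is a minimal sketch of both checkpointing variants; the toy model, data, and the ckpts/ paths are made-up placeholders:

import os
import numpy as np
import tensorflow as tf

# Toy model and data, just to make the snippet runnable
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(16,))])
model.compile("adam", "sparse_categorical_crossentropy")
x = np.random.rand(256, 16)
y = np.random.randint(0, 10, size=256)
os.makedirs("ckpts", exist_ok=True)

# Variant 1: weights only -- the optimizer state is NOT in the checkpoint
weights_ckpt = tf.keras.callbacks.ModelCheckpoint(
    "ckpts/weights.{epoch:02d}.h5", save_weights_only=True)

# Variant 2: full model -- architecture, weights and optimizer state
full_ckpt = tf.keras.callbacks.ModelCheckpoint(
    "ckpts/model.{epoch:02d}.h5", save_weights_only=False)

model.fit(x, y, epochs=3, callbacks=[weights_ckpt, full_ckpt])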
The best workaround I've found is to use a custom training loop and manually save the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step) (roughly as sketched below). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least not without a lot of work, and messing with the LR schedule as I've done feels messy as well.
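A rough sketch of that workaround, with a made-up toy model, LR schedule, and file names; note it restores the step counter but not the optimizer's slot variables (e.g. Adam moments):

import json
import os
import tensorflow as tf
from tensorflow import keras

# Toy model with an LR schedule, so there is a step counter worth restoring
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
lr = keras.optimizers.schedules.ExponentialDecay(1e-3, decay_steps=1000, decay_rate=0.9)
model.compile(keras.optimizers.Adam(lr), "mse")

os.makedirs("resume", exist_ok=True)

# --- on save ---
model.save_weights("resume/weights.h5")
with open("resume/step.json", "w") as f:
    json.dump({"step": int(model.optimizer.iterations.numpy())}, f)

# --- on resume ---
model.load_weights("resume/weights.h5")
with open("resume/step.json") as f:
    step = json.load(f)["step"]
# Put the step counter back so the LR schedule continues where it left off;
# the optimizer's slot variables are still lost, which is the remaining problem.
keras.backend.set_value(model.optimizer.iterations, step)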
Am I missing something? How are people out there saving/resuming using this API?
You can use the model.stop_training attribute to stop the training.
The EarlyStopping callback from keras.callbacks stops training when a monitored quantity has stopped improving; its patience argument is the number of epochs with no improvement after which training will be stopped.
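A minimal sketch of how EarlyStopping is typically used; the toy model, data, and the monitor/patience values are placeholders:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(16,))])
model.compile("adam", "sparse_categorical_crossentropy")
x = np.random.rand(256, 16)
y = np.random.randint(0, 10, size=256)

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",            # quantity to watch
    patience=5,                    # epochs with no improvement before stopping
    restore_best_weights=True)     # roll back to the best epoch's weights

model.fit(x, y, validation_split=0.2, epochs=100, callbacks=[early_stop])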
According to the Keras documentation, a model saved with model.save(filepath) contains: the architecture of the model, allowing it to be re-created; the training configuration (loss, optimizer); and the state of the optimizer, allowing you to resume training exactly where you left off.
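A minimal sketch of saving the full model and resuming from it; the toy model, data, and file name are made up for illustration:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(16,))])
model.compile("adam", "sparse_categorical_crossentropy")
x = np.random.rand(256, 16)
y = np.random.randint(0, 10, size=256)

model.fit(x, y, epochs=3)
model.save("checkpoint.h5")   # architecture + training config + optimizer state

# Later, possibly in a fresh process:
model = tf.keras.models.load_model("checkpoint.h5")
model.fit(x, y, epochs=3)     # continues training with the restored optimizer state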
Using a start/stop/resume training approach with Keras, we achieved 94.14% validation accuracy. At this point the learning rate has become so small that the corresponding weight updates are also very small, implying that the model cannot learn much more. When resuming training for phase 3 (Figure 5), I only let the network train for 5 epochs before killing the script, because no significant learning progress was being made.
You're right, there isn't built-in support for resumability - which is exactly what motivated me to create DeepTrain. It's like Pytorch Lightning (better and worse in different regards) for TensorFlow/Keras.
Why another library? Don't we have enough? You have nothing like this; if there was, I'd not build it. DeepTrain is tailored for the "babysitting approach" to training: train fewer models, but train them thoroughly. Closely monitor each stage to diagnose what's wrong and how to fix it.
Inspiration came from my own use; I'd see "validation spikes" throughout a long epoch, and couldn't afford to pause as it'd restart the epoch or otherwise disrupt the train loop. And forget knowing which batch you were fitting, or how many remain.
How's it compare to Pytorch Lightning? Superior resumability and introspection, along with unique training debug utilities - but Lightning fares better in other regards. I have a comprehensive comparison in the works and will post it within a week.
Pytorch support coming? Maybe. If I can convince the Lightning dev team to make up for its shortcomings relative to DeepTrain, then no - otherwise probably. In the meantime, you can explore the gallery of Examples.
Minimal example:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator
ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')
dg = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val", labels_path="data/val/labels.npy")
tg = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")
tg.train()
You can KeyboardInterrupt at any time, inspect the model, train state, and data generator - and resume.
tf.keras.callbacks.experimental.BackupAndRestore, an API for resuming training from interruptions, was added in tensorflow>=2.3. It works great in my experience.
Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore
tf.keras.callbacks.BackupAndRestore can take care of this (in older TensorFlow releases it lives under tf.keras.callbacks.experimental.BackupAndRestore). Just use the callback as:
callback = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="backup_directory")
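For context, a minimal sketch of how this callback is typically passed to model.fit; the toy model, data, and backup directory name are made up. If the run is interrupted and the script is rerun, training resumes from the last backed-up epoch:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(16,))])
model.compile("adam", "sparse_categorical_crossentropy")
x = np.random.rand(256, 16)
y = np.random.randint(0, 10, size=256)

backup = tf.keras.callbacks.experimental.BackupAndRestore(backup_dir="training_backup")

# If this run is killed and the script is executed again, fit() picks up
# from the last completed epoch recorded under training_backup/.
model.fit(x, y, epochs=20, callbacks=[backup])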