
Keras - no good way to stop and resume training?

After a lot of research, it seems like there is no good way to properly stop and resume training with a TensorFlow 2 / Keras model. This is true whether you are using model.fit() or a custom training loop.

There seem to be 2 supported ways to save a model while training:

  1. Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I've seen; however, it has a number of major issues:

    • The optimizer state is not saved, meaning training resumption will not be correct.
    • Learning rate schedule is reset - this can be catastrophic for some models.
    • TensorBoard logs go back to step 0, making logging essentially useless unless complex workarounds are implemented.
  2. Save the entire model, optimizer, etc. using model.save() or save_weights_only=False. The optimizer state is saved (good) but the following issues remain:

    • TensorBoard logs still go back to step 0
    • Learning rate schedule is still reset (!!!)
    • It is impossible to use custom metrics.
    • This doesn't work at all when using a custom training loop - custom training loops use a non-compiled model, and saving/loading a non-compiled model doesn't seem to be supported.
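To make the full-model path concrete, here is a minimal sketch (toy model and file name are illustrative, not from the question): model.save() stores the optimizer state alongside the weights, and load_model() restores it, but the TensorBoard step and learning-rate bookkeeping still restart as described above.

```python
import numpy as np
import tensorflow as tf

# Toy compiled model (illustrative only).
model = tf.keras.Sequential([tf.keras.Input((4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.randn(32, 4).astype("float32")
y = np.random.randn(32, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

# Full save: architecture + weights + optimizer state.
model.save("full_model.h5")

# Restores the compiled model, including Adam's slot variables;
# custom metrics/objects would need to be passed via custom_objects.
restored = tf.keras.models.load_model("full_model.h5")
```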

The best workaround I've found is to use a custom training loop, manually saving the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least without a lot of work. And messing with the LR schedule as I've done feels messy as well.
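As a concrete illustration of that workaround (a minimal sketch, not the asker's actual code; the file names ckpt.weights.h5 and step.txt are made up for the example): save the weights plus a step counter, and on resume push the step back into optimizer.iterations so the LR schedule picks up where it left off. The sketch uses the variable's assign() method, which does the same thing as the keras.backend.set_value call mentioned above.

```python
import os
import numpy as np
from tensorflow import keras

# Toy model whose LR schedule depends on optimizer.iterations.
schedule = keras.optimizers.schedules.ExponentialDecay(
    1e-2, decay_steps=100, decay_rate=0.9)
model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.Adam(schedule), loss="mse")

weights_path, step_path = "ckpt.weights.h5", "step.txt"

# --- resume: restore weights and the global step ---
step = 0
if os.path.exists(weights_path):
    model.load_weights(weights_path)
    with open(step_path) as f:
        step = int(f.read())
    # Realign the LR schedule; Adam's slot variables are still lost.
    model.optimizer.iterations.assign(step)

x = np.random.randn(32, 4).astype("float32")
y = np.random.randn(32, 1).astype("float32")
for _ in range(10):
    model.train_on_batch(x, y)
    step += 1

# --- save: weights + step only (no optimizer state) ---
model.save_weights(weights_path)
with open(step_path, "w") as f:
    f.write(str(step))
```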

Am I missing something? How are people out there saving/resuming using this API?

Daniel, asked Sep 07 '20




4 Answers

You're right, there isn't built-in support for resumability - which is exactly what motivated me to create DeepTrain. It's like Pytorch Lightning (better and worse in different regards) for TensorFlow/Keras.

Why another library? Don't we have enough? There's nothing like this; if there were, I wouldn't have built it. DeepTrain is tailored for the "babysitting approach" to training: train fewer models, but train them thoroughly. Closely monitor each stage to diagnose what's wrong and how to fix it.

Inspiration came from my own use; I'd see "validation spikes" throughout a long epoch, and couldn't afford to pause, as that would restart the epoch or otherwise disrupt the train loop. And forget knowing which batch you were fitting, or how many remained.

How does it compare to Pytorch Lightning? Superior resumability and introspection, along with unique train debugging utilities - but Lightning fares better in other regards. I have a comprehensive comparison in the works and will post it within a week.

Pytorch support coming? Maybe. If I can convince the Lightning dev team to make up for its shortcomings relative to DeepTrain, then no - otherwise probably. In the meantime, you can explore the gallery of Examples.


Minimal example:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")
tg  = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")

tg.train()

You can KeyboardInterrupt at any time, inspect the model, train state, data generator - and resume.

OverLordGoldDragon, answered Oct 22 '22


The tf.keras.callbacks.experimental.BackupAndRestore API for resuming training from interruptions was added in tensorflow>=2.3. It works great in my experience.

Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore
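For reference, a minimal usage sketch (toy model and the backup_dir name are illustrative): the callback checkpoints model and optimizer state to backup_dir during fit(), and re-running the same script after an interruption resumes from the last completed epoch rather than epoch 0. This sketch uses the non-experimental tf.keras.callbacks.BackupAndRestore path available in newer TensorFlow (2.11+); on TF 2.3-2.10 use the experimental path from the answer instead.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input((4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.randn(64, 4).astype("float32")
y = np.random.randn(64, 1).astype("float32")

# Writes checkpoints (model weights + optimizer state + epoch counter)
# to backup_dir at the end of each epoch.
backup = tf.keras.callbacks.BackupAndRestore(backup_dir="./training_backup")

# If this process dies mid-training and the script is re-run,
# fit() restores from ./training_backup and continues where it stopped.
history = model.fit(x, y, epochs=5, callbacks=[backup], verbose=0)
```

Successful completion of fit() deletes the backup, so a finished run does not accidentally "resume" on the next invocation.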

yanp, answered Oct 22 '22


tf.keras.callbacks.BackupAndRestore can take care of this.

mehran, answered Oct 22 '22


Just create the callback and pass it to fit():

callback = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="backup_directory")
model.fit(..., callbacks=[callback])

Moshiur Rahman Faisal, answered Oct 22 '22