I ran into a problem when I resumed training my model and visualized the progress in TensorBoard.
My question is: how do I resume training from the same step without specifying the epoch manually? Ideally, loading the saved model would be enough: it would read the global_step
from the saved optimizer and continue training from there.
The code below reproduces the problem.
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.models import load_model
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# First run: epochs 0 to 9 are logged to ./logs
model.fit(x_train, y_train, epochs=10, callbacks=[TensorBoard()])
model.save('./final_model.h5', include_optimizer=True)
del model
model = load_model('./final_model.h5')
# Second run: the epoch counter starts at 0 again, so the TensorBoard curves restart
model.fit(x_train, y_train, epochs=10, callbacks=[TensorBoard()])
You can run TensorBoard with the command:
tensorboard --logdir ./logs
To continue training a loaded model with checkpoints, simply call model.fit again with the checkpoint callback still passed. This, however, overwrites the currently saved best model, so change the checkpoint file path if that is undesired.
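One way to change the checkpoint file path between runs is to pick the first file name that is not taken yet. This is only a minimal sketch: `fresh_checkpoint_path` and the `best_model` stem are hypothetical names for illustration, not part of Keras.

```python
import os
import tempfile


def fresh_checkpoint_path(directory, stem="best_model", ext=".h5"):
    """Return a path in `directory` that does not exist yet.

    Tries stem_run1.ext, stem_run2.ext, ... so a checkpoint written
    by an earlier run is never overwritten.
    """
    run = 1
    while True:
        candidate = os.path.join(directory, f"{stem}_run{run}{ext}")
        if not os.path.exists(candidate):
            return candidate
        run += 1


# Demonstration in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    first = fresh_checkpoint_path(tmp)
    open(first, "w").close()  # pretend the first run saved a model here
    second = fresh_checkpoint_path(tmp)
    print(os.path.basename(first), os.path.basename(second))
```

The returned path would then be passed as the first argument to ModelCheckpoint for the new run.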
The TensorBoard callback, in turn, is the implementation defined by the Keras API. In our case, we save logs at ./logs, generate weight histograms after each epoch, and write weight images to our logs.
You can set the parameter initial_epoch in model.fit() to the epoch you want training to start from. Keep in mind that the model trains until the epoch of index epochs is reached (and not for a number of iterations given by epochs).
In your example, the first run covered epochs 0 to 9, so to train for 10 more epochs it should be:
model.fit(x_train, y_train, initial_epoch=10, epochs=20, callbacks=[TensorBoard()])
This lets TensorBoard plot the resumed run as a continuation of the first one. More information about these parameters can be found in the docs.
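To avoid typing the epoch by hand, both arguments can be computed from the number of epochs that have already finished. The sketch below is only an illustration: `resume_args` and `epochs_from_iterations` are hypothetical helper names, and taking the step count from the loaded model's optimizer is an assumption that the optimizer state was saved and restored, which is worth checking on your TensorFlow version.

```python
def resume_args(completed_epochs, extra_epochs):
    """Return (initial_epoch, epochs) for model.fit when resuming.

    Keras loops over range(initial_epoch, epochs), so after
    `completed_epochs` finished epochs the next index to run is
    `completed_epochs` itself.
    """
    if completed_epochs < 0 or extra_epochs < 0:
        raise ValueError("epoch counts must be non-negative")
    return completed_epochs, completed_epochs + extra_epochs


def epochs_from_iterations(iterations, steps_per_epoch):
    """Estimate the number of finished epochs from a step counter."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return iterations // steps_per_epoch


# MNIST has 60000 training samples; with the default batch_size=32
# that is 1875 steps per epoch, so 10 epochs correspond to 18750 steps.
done = epochs_from_iterations(18750, 1875)
initial_epoch, epochs = resume_args(done, 10)
print(initial_epoch, epochs)
```

With a loaded model, the step count might be read as int(model.optimizer.iterations.numpy()), and the result passed on as model.fit(x_train, y_train, initial_epoch=initial_epoch, epochs=epochs, callbacks=[TensorBoard()]).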
It's very simple. Create checkpoints while training the model, then use those checkpoints to resume training from where you left off.
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[TensorBoard()])
model.save('./final_model.h5', include_optimizer=True)
model = load_model('./final_model.h5')
callbacks = []
tensorboard = TensorBoard()
callbacks.append(tensorboard)
file_path = "model-{epoch:02d}-{loss:.4f}.hdf5"
# Create checkpoints here and save them according to your needs.
# period is the number of epochs between saves during training
# (newer Keras versions replace it with save_freq='epoch').
# save_weights_only should be False in your case, so the whole model is saved.
checkpoints = ModelCheckpoint(file_path, monitor='loss', verbose=1, period=1, save_weights_only=False)
callbacks.append(checkpoints)
model.fit(x_train, y_train, epochs=10, callbacks=callbacks)
After this, just load the checkpoint you want to resume from and continue training:
model = load_model(checkpoint_of_choice)
model.fit(x_train, y_train, epochs=10, callbacks=callbacks)
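Note that a plain epochs=10 starts the epoch counter at 0 again, so TensorBoard would plot the resumed run from step 0. One possible way around this is to recover the epoch from the checkpoint file name and pass it as initial_epoch. The helpers `epoch_from_checkpoint` and `latest_checkpoint` below are hypothetical and assume file names that follow the model-{epoch:02d}-{loss:.4f}.hdf5 pattern used above.

```python
import re
from pathlib import Path

# Matches names produced by "model-{epoch:02d}-{loss:.4f}.hdf5".
_CKPT_RE = re.compile(r"^model-(\d+)-(\d+\.\d+)\.hdf5$")


def epoch_from_checkpoint(path):
    """Return the epoch number embedded in a checkpoint file name.

    ModelCheckpoint fills {epoch} with a 1-based count, so the value
    equals the number of finished epochs and can be used directly as
    initial_epoch.
    """
    match = _CKPT_RE.match(Path(path).name)
    if match is None:
        raise ValueError(f"unrecognised checkpoint name: {path}")
    return int(match.group(1))


def latest_checkpoint(paths):
    """Pick the checkpoint with the highest epoch, or None if there is none."""
    paths = list(paths)
    return max(paths, key=epoch_from_checkpoint) if paths else None


names = ["model-02-0.1503.hdf5", "model-10-0.0211.hdf5", "model-09-0.0250.hdf5"]
best = latest_checkpoint(names)
print(best, epoch_from_checkpoint(best))
```

With n = epoch_from_checkpoint(checkpoint_of_choice), the resumed call could then look like model.fit(x_train, y_train, initial_epoch=n, epochs=n + 10, callbacks=callbacks).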
And you are done.
Let me know if you have more questions about this.
Here is sample code in case someone needs it. It implements the idea proposed by Abhinav Anand:
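from glob import glob
from os.path import join

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.models import load_model

# fold_dir is the directory that holds this run's checkpoints and logs.
# Save a checkpoint after every epoch; {epoch} in the file name is 1-based.
mca = ModelCheckpoint(join(fold_dir, 'model_{epoch:03d}.h5'),
                      monitor='loss',
                      save_best_only=False)
tb = TensorBoard(log_dir=join(fold_dir, 'logs'),
                 write_graph=True,
                 write_images=True)

# Resume from the newest checkpoint if one exists.
files = sorted(glob(join(fold_dir, 'model_???.h5')))
if files:
    model_file = files[-1]
    # 'model_010.h5'[-6:-3] == '010', the number of finished epochs
    initial_epoch = int(model_file[-6:-3])
    print('Resuming using saved model %s.' % model_file)
    model = load_model(model_file)
else:
    model = nn.model()
    initial_epoch = 0

model.fit(x_train,
          y_train,
          epochs=100,
          initial_epoch=initial_epoch,
          callbacks=[mca, tb])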
Replace nn.model() with your own function for defining the model.