I have a Keras NN that I want to train and validate using two sets of data, and then test its final performance using a third set. To avoid having to rerun the training every time I restart my Google Colab runtime or want to change my test data, I want to save the final state of the model after training in one script and then load it again in another script.
I've looked everywhere and it seems that model.save("content/drive/My Drive/Directory/ModelName", save_format='tf') should do the trick, but even though it outputs INFO:tensorflow:Assets written to: content/drive/My Drive/Directory/ModelName/assets nothing appears in my Google Drive, so I assume it isn't actually saving.
Please can someone help me solve this issue?
Thanks in advance!
A standard way of saving and retrieving your model's state after Google Colab terminates your connection is to use ModelCheckpoint. This is a Keras callback that runs after each epoch and can save your model, for instance, every time the monitored metric improves. Here are the steps needed to accomplish what you want:
Use this code in order to connect to Google Drive:
from google.colab import drive
drive.mount('/content/gdrive')
You'll then be presented with a link. Open it, authorize Google Colab, and copy the given code into the text box in the notebook.
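One likely reason nothing appeared in Drive in the question is that the save path there, content/drive/My Drive/Directory/ModelName, has no leading slash, so it is resolved relative to the current working directory rather than to the mount point. The sketch below illustrates this with plain path arithmetic. It assumes the notebook's working directory is /content (the usual Colab default) and that Drive was mounted at /content/drive; the helper names resolve and is_on_drive are hypothetical.

```python
import posixpath


def resolve(path, cwd):
    """Return the absolute, normalised form of `path` as seen from `cwd`."""
    return posixpath.normpath(posixpath.join(cwd, path))


def is_on_drive(path, mount_point, cwd):
    """True if `path`, resolved from `cwd`, lies inside `mount_point`."""
    resolved = resolve(path, cwd)
    mount = posixpath.normpath(mount_point)
    return resolved == mount or resolved.startswith(mount + "/")


# Assumed values: Colab notebooks normally start in /content.
CWD = "/content"
MOUNT = "/content/drive"

relative = "content/drive/My Drive/Directory/ModelName"    # path from the question
absolute = "/content/drive/My Drive/Directory/ModelName"   # same path with a leading slash

print(resolve(relative, CWD))             # where the relative path really points
print(is_on_drive(relative, MOUNT, CWD))  # False: it lands outside the mount point
print(is_on_drive(absolute, MOUNT, CWD))  # True: the absolute path is inside it
```

Under these assumptions, the model from the question would have been written to a folder on the Colab machine's local disk, which is why the save reported success while Drive stayed empty.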

This is how you could define your ModelCheckpoint callback:
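from keras.callbacks import ModelCheckpoint
# File name template; Keras fills in the epoch number and the validation accuracy.
filepath = "/content/gdrive/My Drive/MyCNN/epochs:{epoch:03d}-val_acc:{val_acc:.3f}.hdf5"
# Save only when val_acc improves on the best value seen so far.
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]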
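To see what file names this template produces, the sketch below fills it in with Python's str.format, which is, as far as I know, how Keras expands the placeholders. The values are made up for illustration. It also shows that a placeholder with no matching metric raises a KeyError; this matters because newer Keras versions may log the metric as val_accuracy rather than val_acc, in which case both the template and monitor would need the new name.

```python
# Same file name pattern as in the checkpoint definition (directory omitted).
template = "epochs:{epoch:03d}-val_acc:{val_acc:.3f}.hdf5"

# Illustrative values only: epoch 47 with a validation accuracy of 0.905.
logs = {"val_acc": 0.905}
name = template.format(epoch=47, **logs)
print(name)  # epochs:047-val_acc:0.905.hdf5

# If the metric is logged under a different key, formatting fails.
other_logs = {"val_accuracy": 0.905}
try:
    template.format(epoch=47, **other_logs)
    missing_key = None
except KeyError as err:
    missing_key = err.args[0]
print(missing_key)  # val_acc
```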
Then pass the callback list to model.fit while you're training the model, so that Keras runs the checkpoint after each epoch and saves the model's state:
model.fit(X_train, y_train,
          batch_size=64,
          epochs=epochs,
          verbose=1,
          validation_data=(X_val, y_val),
          callbacks=callbacks_list)
Finally, after your session has been terminated, you can restore your previous model's state by running the following code. Don't forget to re-define your model first; only the weights are loaded at this stage.
model.load_weights('/content/gdrive/My Drive/MyCNN/epochs:047-val_acc:0.905.hdf5')
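Since the file name encodes the epoch and the score, the checkpoint to load can also be picked programmatically rather than typed by hand. The following is a minimal sketch, assuming the file name pattern used above; the helper best_checkpoint and the sample file names are hypothetical.

```python
import re

# Matches names such as "epochs:047-val_acc:0.905.hdf5".
PATTERN = re.compile(r"epochs:(\d+)-val_acc:(\d+\.\d+)\.hdf5$")


def best_checkpoint(filenames):
    """Return the file name with the highest val_acc, or None if none match."""
    scored = []
    for name in filenames:
        match = PATTERN.search(name)
        if match:
            epoch, val_acc = int(match.group(1)), float(match.group(2))
            # Ties on val_acc are broken in favour of the later epoch.
            scored.append((val_acc, epoch, name))
    return max(scored)[2] if scored else None


# Made-up directory listing for illustration.
files = [
    "epochs:003-val_acc:0.712.hdf5",
    "epochs:047-val_acc:0.905.hdf5",
    "epochs:019-val_acc:0.861.hdf5",
    "notes.txt",
]
print(best_checkpoint(files))
```

In a notebook, the list could come from os.listdir on the checkpoint directory, and the result, joined with that directory, would be the path passed to model.load_weights.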
Hope that this answers your question.