
Modify trained model architecture and continue training Keras

I want to train a model in a sequential manner. That is, I want to train the model initially with a simple architecture and, once it is trained, add a couple of layers and continue training. Is it possible to do this in Keras? If so, how?

I tried to modify the model architecture, but the changes do not take effect until I compile. Once I compile, all the weights are re-initialized and I lose all the trained information.

All the questions I found on the web and on SO are either about loading a pre-trained model and continuing training, or about modifying the architecture of a pre-trained model and then only testing it. I didn't find anything related to my question. Any pointers are highly appreciated.

PS: I'm using the Keras that ships with the TensorFlow 2.0 package.

Nagabhushan S N asked Mar 09 '20 01:03


1 Answer

Without knowing the details of your model, the following snippet might help:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input

# Train your initial model
def get_initial_model():
    ...
    return model

model = get_initial_model()
model.fit(...)
model.save_weights('initial_model_weights.h5')

# Use the functional Model API to build a new model on top of your initial model
initial_model = get_initial_model()
initial_model.load_weights('initial_model_weights.h5')

nn_input = Input(...)
x = initial_model(nn_input)
x = Dense(...)(x)  # This is the additional layer, connected to your initial model
nn_output = Dense(...)(x)

# Combine your model
full_model = Model(inputs=nn_input, outputs=nn_output)

# Compile and train as usual
full_model.compile(...)
full_model.fit(...)

Basically, you train your initial model and save its weights. You then rebuild the same architecture, reload the weights, and wrap it together with your additional layers using the Model API. If you are not familiar with the Model API, check the Keras documentation (as far as I know, the API remains the same in tensorflow.keras 2.0).
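To make the snippet above concrete, here is a minimal sketch with the `...` placeholders filled in by made-up values. The toy data, layer sizes, layer names, optimizer, and loss are all assumptions chosen for illustration, not something from your setup. It also skips the save/reload step and reuses the trained model object directly, which assumes both training stages run in the same process. The check in the middle records a trained kernel before the combined model is compiled and compares it afterwards.

```python
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Hypothetical toy data: 64 samples, 8 features, 1 regression target.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 8)).astype("float32")
y_train = rng.normal(size=(64, 1)).astype("float32")


def get_initial_model():
    # Assumed simple architecture; replace with your own.
    inputs = Input(shape=(8,))
    hidden = Dense(16, activation="relu", name="initial_hidden")(inputs)
    outputs = Dense(1, name="initial_output")(hidden)
    return Model(inputs=inputs, outputs=outputs, name="initial_model")


# Stage 1: train the simple model.
initial_model = get_initial_model()
initial_model.compile(optimizer="adam", loss="mse")
initial_model.fit(x_train, y_train, epochs=2, verbose=0)

# Keep a copy of one trained kernel so we can compare it later.
trained_kernel = initial_model.get_layer("initial_hidden").get_weights()[0].copy()

# Stage 2: wrap the trained model and stack new layers on top of it.
nn_input = Input(shape=(8,))
x = initial_model(nn_input)
x = Dense(4, activation="relu")(x)  # additional layer (assumed size)
nn_output = Dense(1)(x)
full_model = Model(inputs=nn_input, outputs=nn_output)

# Compiling the combined model does not touch the wrapped model's weights.
full_model.compile(optimizer="adam", loss="mse")
kernel_after_compile = initial_model.get_layer("initial_hidden").get_weights()[0]
weights_kept = np.array_equal(trained_kernel, kernel_after_compile)

# Continue training; the old layers and the new layers are updated together.
full_model.fit(x_train, y_train, epochs=2, verbose=0)
```

If the two stages run in separate scripts, the `save_weights` / `load_weights` round trip from the snippet above is still needed in between.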

Note that you need to check that the output shape of your initial model's final layer is compatible with the additional layers (e.g. you might want to remove the final Dense layer from your initial model if you are only using it for feature extraction).
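One way to drop the final Dense layer, sketched under the assumption that the initial model was built with the functional API and that the layer you want to keep has a known name (the names `features` and `old_head` and all layer sizes here are hypothetical):

```python
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Stand-in for an already trained initial model (assumed architecture).
inputs = Input(shape=(8,))
hidden = Dense(16, activation="relu", name="features")(inputs)
outputs = Dense(1, name="old_head")(hidden)
initial_model = Model(inputs=inputs, outputs=outputs)

# Cut the model at the layer before the old output layer.
# The new model shares the same layer objects, so trained weights carry over.
feature_extractor = Model(
    inputs=initial_model.input,
    outputs=initial_model.get_layer("features").output,
)

# Attach the additional layers to the 16-dimensional feature output.
x = Dense(8, activation="relu")(feature_extractor.output)
new_output = Dense(1)(x)
full_model = Model(inputs=feature_extractor.input, outputs=new_output)
```

Here the new layers receive the 16 features instead of the single value produced by the old output layer, which is usually what you want when the initial model serves as a feature extractor.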

Toukenize answered Oct 12 '22 14:10