Saving keras models with shared layers

Tags:

keras

I have two Keras models with shared layers whose weights I want to save to HDF5 files. If I save both models individually, I believe the shared layers are saved twice, using double the disk space. How can I save them in a single file?

Thanks!!

asked Dec 26 '18 by Rodrigo Serna Pérez

1 Answer

You can take the shared layers and put them in a separate model. For example, if the shared layers are layer1 and layer2, then you create a model whose input layer is layer1 and whose output layer is layer2. The output of layer2 then feeds into both original models.
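As a minimal sketch of this step (the layer names, types, and sizes are hypothetical, since the original layers aren't shown), the shared layers can be wrapped in their own model like this:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

# Hypothetical shared block: layer1 followed by layer2
shared_input = Input(shape=(100,))
x = Dense(64, activation='relu', name='layer1')(shared_input)
shared_out = Dense(32, activation='relu', name='layer2')(x)
shared_model = Model(inputs=shared_input, outputs=shared_out, name='shared_model')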

If the shared layers are the first layers in both models, the task is easier: after separating the shared layers you will have three models. If the shared layers sit between the input and output layers of each model, then you also have to split off the layers that come before the shared block in each model, which results in two additional models.
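For that second case, here is a sketch of how one of the original models might be reassembled from three pieces; the sub-model names pre_model1 and post_model1 are hypothetical stand-ins for the model-specific parts, and the input shape is assumed:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# pre_model1: layers of model1 before the shared block (model-specific)
# post_model1: layers of model1 after the shared block (model-specific)
inp = Input(shape=(100,))   # hypothetical input shape
x = pre_model1(inp)
x = shared_model(x)         # the shared layers, defined once
out = post_model1(x)
model1_full = Model(inputs=inp, outputs=out)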

To achieve this easily, you can use the Keras functional API to combine multiple models by feeding the output of one model as the input to another. For example, if you have the models shared_model, model1 and model2, you can create the two models for training or inference by passing the output of shared_model as the input to model1 and model2:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

input_layer = Input(shape=input_shape)      # input_shape matches your data, e.g. (100,)
shared_output = shared_model(input_layer)   # run the shared block once

# First combined model: shared layers followed by model1's layers
combined_output1 = model1(shared_output)
combined_model1 = Model(inputs=input_layer, outputs=combined_output1)

# Second combined model: the same shared layers followed by model2's layers
combined_output2 = model2(shared_output)
combined_model2 = Model(inputs=input_layer, outputs=combined_output2)

This way, you can train combined_model1 or combined_model2, or use either of them for inference.
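Since the shared weights now live only in shared_model, saving each sub-model separately stores every weight exactly once. A usage sketch, where x_train and y_train are hypothetical training data and the loss is just an example:

combined_model1.compile(optimizer='adam', loss='mse')
combined_model1.fit(x_train, y_train, epochs=5)

# Each file contains each set of weights exactly once
shared_model.save_weights('shared_model.h5')
model1.save_weights('model1.h5')
model2.save_weights('model2.h5')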

To be able to save checkpoints for the models shared_model and model1 while training combined_model1, you can use the alt-model-checkpoint library to create a callback.
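If you would rather avoid an extra dependency, a plain Keras callback can do the same job. The sketch below (the class name SubModelCheckpoint is hypothetical, not part of alt-model-checkpoint) saves the sub-models' weights at the end of every epoch:

from tensorflow.keras.callbacks import Callback

class SubModelCheckpoint(Callback):
    """Save the given sub-models' weights after every epoch."""
    def __init__(self, submodels):
        super().__init__()
        self.submodels = submodels  # dict mapping file prefix -> model

    def on_epoch_end(self, epoch, logs=None):
        for name, model in self.submodels.items():
            model.save_weights('{}_epoch{:02d}.h5'.format(name, epoch))

# Checkpoint shared_model and model1 while training combined_model1
checkpoint = SubModelCheckpoint({'shared_model': shared_model, 'model1': model1})
combined_model1.fit(x_train, y_train, epochs=5, callbacks=[checkpoint])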

answered Oct 18 '22 by Talal Alrawajfeh