I want to freeze a pre-trained network in Keras. I found base.trainable = False in the documentation, but I don't understand how it works.
Using len(model.trainable_weights), I found that I have 30 trainable weights. How can that be? The model summary shows total trainable params: 16,812,353.
After freezing, I have 4 trainable weights.
Maybe I don't understand the difference between params and weights. I am a beginner in deep learning, so any help is appreciated.
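The two numbers count different things: len(model.trainable_weights) counts weight tensors (typically one kernel and one bias per layer), while the summary counts the scalar parameters inside those tensors. As a sanity check, the figures in the question happen to match a VGG16 convolutional base on 150x150x3 inputs topped with Flatten, Dense(256) and Dense(1). That architecture is a guess, not something the question states, so treat the sketch below as an illustration under that assumption. It uses plain Python arithmetic on the layer shapes, no Keras required.

```python
from math import prod

# ASSUMPTION: VGG16 conv base (13 conv layers, 3x3 kernels) on a 150x150x3
# input, followed by Flatten -> Dense(256) -> Dense(1). Not confirmed by the question.
VGG16_CONV_BLOCKS = [[64, 64], [128, 128], [256, 256, 256],
                     [512, 512, 512], [512, 512, 512]]


def conv_base_shapes(in_channels=3):
    """(kernel_shape, bias_shape) for every conv layer of the assumed base."""
    shapes = []
    for block in VGG16_CONV_BLOCKS:
        for filters in block:
            shapes.append(((3, 3, in_channels, filters), (filters,)))
            in_channels = filters
    return shapes


# Five 2x2 max-pools: 150 -> 75 -> 37 -> 18 -> 9 -> 4, so Flatten yields 4*4*512 = 8192.
head_shapes = [((4 * 4 * 512, 256), (256,)), ((256, 1), (1,))]


def count(shapes):
    """Return (number of weight tensors, number of scalar parameters)."""
    tensors = 2 * len(shapes)  # one kernel tensor + one bias tensor per layer
    params = sum(prod(k) + prod(b) for k, b in shapes)
    return tensors, params


base = count(conv_base_shapes())
head = count(head_shapes)
print("base:", base)  # base: (26, 14714688)
print("head:", head)  # head: (4, 2097665)
print("total:", (base[0] + head[0], base[1] + head[1]))  # total: (30, 16812353)
```

Under this assumption, 30 tensors hold 16,812,353 parameters, and freezing the base leaves exactly the head's 4 tensors (two kernels, two biases) trainable.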
A Keras Model is trainable by default. You have two means of freezing all the weights:
(1) model.trainable = False, before compiling the model
(2) for layer in model.layers: layer.trainable = False, which works before and after compiling
(1) must be done before compilation, since Keras treats model.trainable as a boolean flag at compile time and performs (2) under the hood. After doing either of the above, you should see:
print(model.trainable_weights)
# []
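A minimal sketch of both methods, assuming TensorFlow 2.x with tf.keras. The tiny model built by the hypothetical helper build_model stands in for a pre-trained network. Following the current Keras transfer-learning guide, which says compile() should be called again after changing any trainable value, both variants set the flags first and compile afterwards.

```python
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for a pre-trained network:
    # two Dense layers, each holding one kernel tensor and one bias tensor.
    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(4, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


# Method (1): set the flag on the whole model, then compile.
model_a = build_model()
print(len(model_a.trainable_weights))  # 4
model_a.trainable = False
model_a.compile(optimizer="adam", loss="mse")
print(model_a.trainable_weights)  # []
# The model-level flag is propagated to every layer it contains.
print(all(not layer.trainable for layer in model_a.layers))  # True

# Method (2): set the flag layer by layer, then compile.
model_b = build_model()
for layer in model_b.layers:
    layer.trainable = False
model_b.compile(optimizer="adam", loss="mse")
print(model_b.trainable_weights)  # []
```

The frozen tensors are not deleted; they move to model.non_trainable_weights, so the total weight count stays the same.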
Regarding the docs: they are likely outdated. The linked source code above reflects the up-to-date behavior.
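The base.trainable = False line from the question freezes only a nested sub-model, not the whole network, which is why some trainable weights remain. A minimal sketch, assuming TensorFlow 2.x with tf.keras; base and its two Dense layers are hypothetical stand-ins for a real pre-trained base.

```python
import tensorflow as tf

# Hypothetical pre-trained base: two Dense layers (4 weight tensors).
base_inputs = tf.keras.Input(shape=(8,))
h = tf.keras.layers.Dense(16, activation="relu")(base_inputs)
h = tf.keras.layers.Dense(16, activation="relu")(h)
base = tf.keras.Model(base_inputs, h, name="base")

# Full model: the base nested as a layer, plus a two-layer head.
inputs = tf.keras.Input(shape=(8,))
x = base(inputs)
x = tf.keras.layers.Dense(4, activation="relu")(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
print(len(model.trainable_weights))  # 8: base (4) + head (4)

base.trainable = False  # freeze only the nested base
model.compile(optimizer="adam", loss="mse")  # compile after changing flags
print(len(model.trainable_weights))      # 4: head kernels and biases
print(len(model.non_trainable_weights))  # 4: the frozen base
```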