 

How can I get the number of trainable parameters of a model in Keras?

Tags:

python

keras

I am setting trainable=False on all my layers (the model is built with the Model API), but I want to verify that this is working. model.count_params() returns the total number of parameters. Is there any way to get the total number of trainable parameters other than reading the last few lines of model.summary()?

asked Jul 12 '17 by Prabaha







3 Answers

import numpy as np
from keras import backend as K

# set() deduplicates the weight variables. This requires them to be hashable,
# which holds for older Keras / TensorFlow 1.x backends.
trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

This snippet is taken from the end of the layer_utils.print_summary() definition, which model.summary() calls.
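Every count in that snippet reduces to the same arithmetic: each weight tensor contributes the product of its shape dimensions, and the trainable/non-trainable split depends only on which list the weight appears in. As a worked check that runs without Keras, here is a minimal sketch using hypothetical weight shapes for a Dense(4) layer on 3 input features followed by BatchNormalization. The shapes assume Keras's usual weight layout (kernel and bias for Dense; gamma, beta, moving_mean and moving_variance for BatchNormalization, with the moving statistics non-trainable).

```python
from math import prod

# Hypothetical weight shapes: Dense(4) on 3 inputs, then BatchNormalization on 4 units.
# Each entry is (name, shape, trainable).
weights = [
    ("dense/kernel", (3, 4), True),
    ("dense/bias", (4,), True),
    ("batch_norm/gamma", (4,), True),
    ("batch_norm/beta", (4,), True),
    ("batch_norm/moving_mean", (4,), False),
    ("batch_norm/moving_variance", (4,), False),
]


def count(weights, trainable):
    """Sum the element counts (product of shape dims) of the selected weights."""
    return sum(prod(shape) for _, shape, t in weights if t == trainable)


trainable_count = count(weights, True)
non_trainable_count = count(weights, False)

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
```

This prints 32 total, 24 trainable (12 + 4 + 4 + 4) and 8 non-trainable parameters.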


Edit: more recent versions of Keras have a helper function, count_params(), for this purpose:

from keras.utils.layer_utils import count_params

trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
answered Oct 23 '22 by tuomastik



For TensorFlow 2.0:

import numpy as np
import tensorflow.keras.backend as K

# K.count_params(w) returns the number of scalars in weight tensor w.
trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
non_trainable_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
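To read these numbers as the verification the question asks for, it helps to know the rule Keras applies: when a layer's trainable attribute is False, its weights are reported in non_trainable_weights instead of trainable_weights, so a fully frozen model should report a trainable count of 0. The sketch below models that rule in plain Python with a hypothetical ToyLayer class (not a Keras API) so it runs without TensorFlow.

```python
from math import prod


class ToyLayer:
    """Hypothetical stand-in for a Keras layer: weight shapes plus a trainable flag."""

    def __init__(self, shapes, trainable=True):
        self.shapes = shapes
        self.trainable = trainable

    @property
    def trainable_weights(self):
        # A frozen layer reports no trainable weights...
        return self.shapes if self.trainable else []

    @property
    def non_trainable_weights(self):
        # ...and all of its weights are reported as non-trainable instead.
        return [] if self.trainable else self.shapes


def count_params(shapes):
    """Total number of scalars across the given weight shapes."""
    return sum(prod(s) for s in shapes)


layers = [ToyLayer([(3, 4), (4,)]), ToyLayer([(4, 2), (2,)])]

# Freeze every layer, as in the question.
for layer in layers:
    layer.trainable = False

trainable_count = sum(count_params(l.trainable_weights) for l in layers)
non_trainable_count = sum(count_params(l.non_trainable_weights) for l in layers)
print(trainable_count, non_trainable_count)  # 0 26
```

In real Keras, changing trainable after compile() only affects training once the model is compiled again, so recompile after freezing layers.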
answered Oct 23 '22 by Danylo Baibak



For tensorflow.keras, this works for me. It is adapted from the TensorFlow source on GitHub, specifically the function print_layer_summary_with_connections() in layer_utils.py:

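import numpy as np
from tensorflow.python.util import object_identity  # private TensorFlow module

def count_params(weights):
    # ObjectIdentitySet deduplicates the variables by object identity,
    # because TF2 variables cannot be placed in a plain set().
    return int(sum(np.prod(p.shape.as_list())
                   for p in object_identity.ObjectIdentitySet(weights)))

# Some Keras versions store the weights that were trainable at compile time
# in the private attribute _collected_trainable_weights; prefer it if present.
if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
else:
    trainable_count = count_params(model.trainable_weights)

print(trainable_count)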
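ObjectIdentitySet is needed because, in TensorFlow 2, variables are unhashable, so a plain set() like the one in the first answer raises a TypeError. The sketch below mimics that situation with a hypothetical FakeVariable class and deduplicates by object identity using id(), which is roughly the idea ObjectIdentitySet implements.

```python
from math import prod


class FakeVariable:
    """Hypothetical stand-in for a TF2 variable, which cannot be put in a set()."""

    __hash__ = None  # makes instances unhashable, like tf.Variable in TF2

    def __init__(self, shape):
        self.shape = shape


def count_params_by_identity(weights):
    """Count scalars, counting each distinct object once (compared by id(), not ==)."""
    unique = {id(w): w for w in weights}
    return sum(prod(w.shape) for w in unique.values())


shared = FakeVariable((10, 10))
weights = [shared, FakeVariable((10,)), shared]  # the shared variable is listed twice

try:
    set(weights)
except TypeError:
    print("set() fails: the variables are unhashable")

print(count_params_by_identity(weights))  # 110
```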
answered Oct 23 '22 by satvik choudhary
