 

How to count the total number of trainable parameters in a TensorFlow model?

Is there a function call or another way to count the total number of parameters in a TensorFlow model?

By parameters I mean: an N-dimensional vector of trainable variables has N parameters, an NxM matrix has N*M parameters, etc. So essentially I'd like to sum, over all the trainable variables in a TensorFlow session, the product of each variable's shape dimensions.
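To make that definition concrete, here is a minimal pure-Python sketch of the sum. It does not use a real TensorFlow session: the variable names and shapes are hypothetical, chosen to resemble a small two-layer dense network (784 -> 128 -> 10), and it assumes Python 3.8+ for `math.prod`.

```python
from math import prod

# Hypothetical shapes for a tiny two-layer dense network (784 -> 128 -> 10):
# each layer has a kernel of shape (inputs, outputs) and a bias of shape (outputs,).
variable_shapes = {
    "dense_1/kernel": [784, 128],
    "dense_1/bias": [128],
    "dense_2/kernel": [128, 10],
    "dense_2/bias": [10],
}

# Parameters of one variable = product of its shape dimensions.
per_variable = {name: prod(shape) for name, shape in variable_shapes.items()}

# Total = sum of those products over all (trainable) variables.
total_parameters = sum(per_variable.values())

print(per_variable)
print(total_parameters)  # 101770
```

Here the first kernel alone contributes 784 * 128 = 100352 parameters, and the biases add only 128 + 10.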

asked Jul 02 '16 13:07 by j314erre





2 Answers

Loop over the shape of every variable in tf.trainable_variables().

```python
import tensorflow as tf  # TensorFlow 1.x API

total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    print(shape)
    print(len(shape))
    variable_parameters = 1
    for dim in shape:
        print(dim)
        variable_parameters *= dim.value
    print(variable_parameters)
    total_parameters += variable_parameters
print(total_parameters)
```

Update: prompted by this answer, I wrote an article clarifying static and dynamic shapes in TensorFlow: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/
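The loop above multiplies `dim.value` directly, so it assumes every trainable variable has a fully defined static shape. If a dimension were unknown statically, `dim.value` would be `None` and the multiplication would raise a `TypeError`. The sketch below mirrors the same loop in plain Python over shape lists (standing in for what `variable.get_shape().as_list()` returns, with `None` for unknown dimensions) and fails with a clearer message instead. The function name `count_parameters` and the example shapes are illustrative assumptions, not part of TensorFlow; it assumes Python 3.8+ for `math.prod`.

```python
from math import prod

def count_parameters(shapes):
    """Sum the element counts of a list of static shapes.

    Each shape is a list of ints; None marks a dimension whose size is not
    known statically (only at run time), so it cannot be counted here.
    """
    total = 0
    for shape in shapes:
        if any(dim is None for dim in shape):
            raise ValueError(f"shape {shape} is not fully defined")
        # A scalar variable has shape [] and prod([]) == 1, i.e. one parameter.
        total += prod(shape)
    return total

# Hypothetical conv kernel (3x3, 1 input channel, 32 filters) plus its bias.
print(count_parameters([[3, 3, 1, 32], [32]]))  # 320
```

Raising rather than skipping keeps an unknown dimension from silently producing an undercount.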

answered Sep 22 '22 11:09 by nessuno



I have an even shorter version, a one-line solution using NumPy:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
```
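One detail of the NumPy one-liner: a scalar variable has shape `[]`, and `np.prod([])` evaluates to the float `1.0`, so the sum can come back as a float. A minimal stdlib-only sketch of the same computation keeps the result a Python `int`. It assumes Python 3.8+ for `math.prod`, and the shapes are hypothetical stand-ins for the lists `v.get_shape().as_list()` would return.

```python
from math import prod

# Hypothetical static shapes; the last one is a scalar variable (shape []),
# which still holds one parameter.
shapes = [[784, 128], [128], []]

# Same idea as the NumPy one-liner, but math.prod and sum stay in Python ints.
total_parameters = sum(prod(shape) for shape in shapes)
print(total_parameters)  # 100481
```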
answered Sep 19 '22 11:09 by Michael Gygli
