 

Check the total number of parameters in a PyTorch model

How do I count the total number of parameters in a PyTorch model? Something similar to model.count_params() in Keras.

Asked by Fábio Perez on Mar 09 '18 at 19:03



People also ask

How do you count parameters in a model PyTorch?

PyTorch doesn't have a dedicated utility function (at least at the moment!) to count the number of model parameters, but modules do have a method you can use to get the parameters themselves: model.parameters(). This method returns an iterator over all of the module's parameters, including those of its submodules.
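Because parameters() is just an iterator, you can walk it (or its sibling named_parameters(), which also yields each parameter's name) to see where the parameters live. A minimal sketch, assuming a small hypothetical two-layer model chosen only for illustration:

```python
import torch.nn as nn

# Hypothetical example model; the layer sizes are arbitrary.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# named_parameters() yields (name, tensor) pairs; ReLU has no parameters, so it is skipped.
counts = {}
for name, p in model.named_parameters():
    counts[name] = p.numel()
    print(f"{name}: shape={tuple(p.shape)}, numel={p.numel()}")
```

Summing the values in counts gives the same total as summing over model.parameters().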

How do you find the number of parameters in a model?

The number of parameters in a conv layer is ((m * n * d) + 1) * k, where m and n are the filter width and height, d is the number of filters (channels) in the previous layer, and k is the number of filters in this layer. The 1 is added because of the bias term for each filter. Written out in words: ((filter width * filter height * number of filters in the previous layer) + 1) * number of filters.
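One way to sanity-check the formula is to compare it against the parameter count PyTorch reports for an nn.Conv2d layer. This sketch assumes arbitrary example sizes, a bias term (the nn.Conv2d default), and the default groups=1; grouped convolutions or bias=False change the count.

```python
import torch.nn as nn

# Assumed example sizes: a 3x5 filter (m=3, n=5), d=3 input channels, k=16 filters.
m, n, d, k = 3, 5, 3, 16
conv = nn.Conv2d(in_channels=d, out_channels=k, kernel_size=(m, n))  # bias=True by default

formula = ((m * n * d) + 1) * k
actual = sum(p.numel() for p in conv.parameters())  # weight (16, 3, 3, 5) + bias (16,)
print(formula, actual)  # 736 736
```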

How do I tell what size PyTorch model I have?

For me, the simplest way is to go to the “Files and versions” tab of a given model on the Hugging Face Hub and check the size in MB/GB of the pytorch_model.bin file (or, equivalently, the Flax/TensorFlow model file).
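For a model you already have loaded, a rough in-memory size can be estimated from the parameters (and registered buffers) directly, since every tensor reports its element count and per-element byte size. This is a sketch, not an exact file size: a saved checkpoint adds serialization metadata and may store a different dtype (e.g. fp16). The helper name model_size_mib is hypothetical.

```python
import torch.nn as nn

def model_size_mib(model: nn.Module) -> float:
    # Bytes held by parameters plus registered buffers (e.g. BatchNorm running stats).
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024**2

model = nn.Linear(1024, 1024)  # float32 by default: 4 bytes per element
print(f"{model_size_mib(model):.2f} MiB")  # 4.00 MiB
```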

What are parameters PyTorch?

Parameters are Tensor subclasses with a very special property when used with Modules: when they're assigned as Module attributes, they are automatically added to the module's list of parameters and will appear, e.g., in the parameters() iterator. Assigning a plain Tensor doesn't have this effect.
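The difference is easy to see with a tiny module that assigns one nn.Parameter and one plain tensor as attributes. The class name Demo and its attributes are hypothetical, chosen only for illustration:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(3))  # registered automatically as a parameter
        self.t = torch.zeros(3)                # plain tensor: not registered

demo = Demo()
names = [name for name, _ in demo.named_parameters()]
print(names)  # ['w']
```

Only w is counted by parameters(), so only it contributes to a parameter count.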


1 Answer

PyTorch doesn't have a function to calculate the total number of parameters as Keras does, but it's possible to sum the number of elements of every parameter tensor:

pytorch_total_params = sum(p.numel() for p in model.parameters())

If you want to calculate only the trainable parameters:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
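To see the two counts diverge, here is a self-contained sketch that freezes the first layer of a small hypothetical model (as one might when fine-tuning only the head) and then applies both one-liners:

```python
import torch.nn as nn

# Hypothetical example model; the layer sizes are arbitrary.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Freeze the first layer so its parameters no longer require gradients.
for p in model[0].parameters():
    p.requires_grad = False

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_params, trainable_params)  # 67 12
```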

Answer inspired by this answer on PyTorch Forums.

Note: I'm answering my own question. If anyone has a better solution, please share with us.

Answered by Fábio Perez on Sep 19 '22 at 21:09
