
PyTorch: What's the difference between state_dict and parameters()?

In order to access a model's parameters in PyTorch, I saw two methods:

using state_dict and using parameters()

I wonder what the difference is, or if one is good practice and the other is bad practice.

Thanks

asked Feb 18 '19 at 11:02 by Gulzar

People also ask

What is State_dict in PyTorch?

A state_dict is an integral entity if you are interested in saving or loading models in PyTorch. Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
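As a minimal sketch of this, the snippet below (using a hypothetical two-layer model) shows that a state_dict is just an ordinary mapping from names to tensors, so its entries can be inspected, altered, and pushed back into the model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

sd = model.state_dict()            # OrderedDict mapping names to tensors
print(list(sd.keys()))             # ['0.weight', '0.bias', '2.weight', '2.bias']

sd["0.bias"] = torch.zeros(8)      # alter one entry like any dict value
model.load_state_dict(sd)          # push the modified dict back into the model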

What is model parameters () PyTorch?

In PyTorch, a Parameter (torch.nn.Parameter) is a special kind of tensor meant to be used inside an nn.Module. A tensor that is assigned as an nn.Parameter attribute inside a custom model is registered as a model parameter and is thus returned by model.parameters(). You can think of a Parameter as a thin wrapper over a tensor that marks it as trainable module state.
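A minimal sketch of this registration behaviour, using a made-up module named Scaler with made-up attribute names:

import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))   # registered automatically as a parameter
        self.offset = torch.zeros(1)               # plain tensor: NOT registered

    def forward(self, x):
        return x * self.scale + self.offset

m = Scaler()
print([name for name, _ in m.named_parameters()])  # ['scale'] -- 'offset' is absent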

What is Optimizer State_dict?

In contrast to a model's state_dict, which saves the learnable parameters (and buffers), the optimizer's state_dict contains information about the optimizer's state for the parameters being optimized (e.g. momentum buffers), as well as the hyperparameters used. All optimizers in PyTorch inherit from the base class torch.optim.Optimizer.
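A minimal sketch of what that looks like, using a hypothetical linear model and SGD with momentum (the optimizer's state_dict has a 'state' entry per parameter plus 'param_groups' holding the hyperparameters):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# run one step so the per-parameter momentum buffers exist
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()

sd = optimizer.state_dict()
print(sd.keys())                                                       # dict_keys(['state', 'param_groups'])
print(sd['param_groups'][0]['lr'], sd['param_groups'][0]['momentum'])  # 0.1 0.9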


2 Answers

parameters() only gives the module's parameters, i.e. weights and biases.

Returns an iterator over module parameters.

You can check the list of the parameters as follows:

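# iterate over all registered parameters and print the names of the trainable ones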
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

On the other hand, state_dict returns a dictionary containing the whole state of the module. If you check its source code, you'll see that it includes not just the parameters but also buffers, etc.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names.

Check all keys that state_dict contains using:

model.state_dict().keys()

For example, in state_dict, you'll find entries like bn1.running_mean and bn1.running_var, which are not present in .parameters().
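A minimal sketch of that contrast, using a hypothetical model named net with a BatchNorm layer (the running statistics are buffers, so they appear only in state_dict()):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

param_names = {name for name, _ in net.named_parameters()}
state_keys = set(net.state_dict().keys())

print(state_keys - param_names)
# {'1.running_mean', '1.running_var', '1.num_batches_tracked'}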


If you only want to access parameters, you can simply use .parameters(), whereas for purposes like saving and loading a model, as in transfer learning, you'll need to save the state_dict, not just the parameters.
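A minimal sketch of that save/load round trip via state_dict (the file name "checkpoint.pt" and the model are made up for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
torch.save(model.state_dict(), "checkpoint.pt")        # persists parameters AND buffers

fresh = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fresh.load_state_dict(torch.load("checkpoint.pt"))     # restore into the same architecture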

answered Oct 25 '22 at 00:10 by kHarshit

Besides the differences in @kHarshit's answer, the attribute requires_grad of trainable tensors in net.parameters() is True, whereas it is False for the corresponding tensors in net.state_dict(), because state_dict() returns detached tensors by default.
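A minimal sketch of that difference, using a hypothetical model named net:

import torch.nn as nn

net = nn.Linear(4, 2)

print(next(net.parameters()).requires_grad)        # True  -- live leaf tensor used by autograd
print(net.state_dict()["weight"].requires_grad)    # False -- detached copy by default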

answered Oct 25 '22 at 00:10 by david