 

What is saved in an optimizer's state_dict? What do "state" and "param_groups" stand for?

When we use the Adam optimizer and want to continue training a network from a pretrained model, we should load not only model.state_dict but also optimizer.state_dict. And if we modify the network's structure, we also have to adjust the saved optimizer state_dict so that loading succeeds.
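What I mean is roughly this (the file name and the surrounding model/optimizer variables are just placeholders; the checkpoint key names match the `b` dict I inspect below):

    import torch

    # save both the model weights and the optimizer state in one checkpoint
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')

    # ... later, to resume training, restore both of them
    b = torch.load('checkpoint.pth')
    model.load_state_dict(b['model_state_dict'])
    optimizer.load_state_dict(b['optimizer_state_dict'])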

But I don't understand some of the entries in the saved optimizer.state_dict, like optim_dict["state"] (dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])) and optim_dict['param_groups'][0]['params']. There are many numbers like these:

In [71]: b['optimizer_state_dict']['state'].keys()
Out[71]: dict_keys([140623218628000, 140623218628072, 140623218628216, 140623218628360, 140623218628720, 140623218628792, 140623218628936, 140623218629080, 140623218629656, 140623218629728, 140623218629872, 140623218630016, 140623218630376, 140623218630448, 140623218716744, 140623218716816, 140623218717392, 140623218717464, 140623218717608, 140623218717752, 140623218718112, 140623218718184, 140623218718328, 140623218718472, 140623218719048, 140623218719120, 140623218719264, 140623218719408, 140623218719768, 140623218719840, 140623218719984, 140623218720128, 140623218720704, 140623209943112, 140623209943256, 140623209943400, 140623209943760, 140623209943832, 140623209943976, 140623209944120, 140623209944696, 140623209944768, 140623209944912, 140623209945056, 140623209945416, 140623209945488, 140623209945632, 140623209945776, 140623209946352, 140623209946424, 140623209946568, 140623209946712, 140623209947072, 140623210041416, 140623210041560, 140623210041704, 140623244033768, 140623244033840, 140623244033696, 140623244033912, 140623244033984, 140623244070984, 140623244071056, 140623244071128, 140623429501576, 140623244071200, 140623244071272, 140623244071344, 140623244071416, 140623244071488, 140623244071560, 140623244071632, 140623244071848, 140623244071920, 140623244072064, 140623244072208, 140623244072424, 140623244072496, 140623244072640, 140623244072784, 140623244073216, 140623244073288, 140623244073432, 140623244073576, 140623244073792, 140623244073864, 140623244074008, 140623244074152, 140623244074584, 140623244074656, 140623244074800, 140623244074944, 140623218540760, 140623218540832, 140623218540976, 140623218541120, 140623218541552, 140623218541624, 140623218541768, 140623218541912, 140623218542128, 140623218542200, 140623218542344, 140623218542488, 140623218542920, 140623218542992, 140623218543136, 140623218543280, 140623218543496, 140623218543568, 140623218543712, 140623218543856, 140623218544288, 140623218544360, 140623218544504, 140623218626632, 140623218626992, 140623218627064, 140623218627208, 140623218627352, 140623218627784, 140623218629440, 140623218717176, 140623218718832, 140623218720488, 140623209944480, 140623209946136, 140623210043000])

In [44]: b['optimizer_state_dict']['state'][140623218628072].keys()
Out[44]: dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])

In [45]: b['optimizer_state_dict']['state'][140623218628072]['exp_avg'].shape
Out[45]: torch.Size([480])
asked Jun 08 '20 by Eric Kani


1 Answer

In contrast to the model's state_dict, which saves the learnable parameters, the optimizer's state_dict contains information about the optimizer's state (the buffers it keeps for the parameters being optimized), as well as the hyperparameters used.


All optimizers in PyTorch need to inherit from the base class torch.optim.Optimizer. Its constructor takes two arguments:

  • params (iterable) – an iterable of torch.Tensors or dicts; specifies which Tensors should be optimized.
  • defaults (dict) – a dict containing default values of optimization options (used when a parameter group doesn’t specify them).

In addition to that, optimizers also support specifying per-parameter options.

To do this, instead of passing an iterable of Tensors, pass in an iterable of dicts. Each of them will define a separate parameter group, and should contain a params key, containing a list of parameters belonging to it.

Consider an example:

    optimizer = optim.SGD([
        {'params': model.base.parameters()},
        {'params': model.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2, momentum=0.9)

Here, we have provided: a) the params, b) the default hyperparameters lr and momentum, and c) a separate parameter group. In this case, model.base’s parameters will use the default learning rate of 1e-2, model.classifier’s parameters will use a learning rate of 1e-3, and a momentum of 0.9 will be used for all parameters.
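If it helps, you can check how these per-group options end up stored by inspecting optimizer.param_groups on the optimizer constructed above (a small sketch; the printed values are indicative):

    for i, group in enumerate(optimizer.param_groups):
        # options not given for a group fall back to the defaults passed to the constructor
        print(i, group['lr'], group['momentum'], len(group['params']))
    # group 0 uses lr=0.01 (the default), group 1 uses lr=0.001 (the override);
    # momentum=0.9 applies to both groups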


The step (optimizer.step()) performs a single optimization step (parameter update), which changes the state of the optimizer.
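In a typical loop that looks roughly like this (criterion, inputs and targets are assumed to exist; this is a generic sketch, not code from the question):

    optimizer.zero_grad()                      # clear gradients from the previous iteration
    loss = criterion(model(inputs), targets)   # forward pass
    loss.backward()                            # compute gradients
    optimizer.step()                           # update parameters and the optimizer's internal state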


Now, coming to the optimizer's state_dict: it returns the state of the optimizer as a dict containing two entries:

  • state - a dict holding the current optimization state, with one entry per parameter. The long integer keys in your output simply identify the individual parameter tensors (older PyTorch versions keyed this dict by the Python id() of each parameter; newer versions use positional indices), and each entry holds that parameter's buffers.
  • param_groups - a list containing all parameter groups (as discussed above), where each group is a dict of its options plus the ids/indices of its params.

Some of the state entries are specific to the optimizer used, e.g. for Adam:

  • step: number of optimization steps taken so far for this parameter
  • exp_avg: exponential moving average of gradient values
  • exp_avg_sq: exponential moving average of squared gradient values
  • max_exp_avg_sq: maximum of the exp_avg_sq values, kept when amsgrad=True
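To see all of this for yourself, you can dump an Adam optimizer's state_dict after a single update (a self-contained sketch; the toy linear model is only for illustration, and amsgrad=True is what makes max_exp_avg_sq appear):

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()                     # populates the per-parameter state

    sd = optimizer.state_dict()
    print(sd.keys())                     # dict_keys(['state', 'param_groups'])
    print(sd['param_groups'])            # options plus the ids/indices of each group's params
    for k, v in sd['state'].items():
        print(k, list(v.keys()))         # e.g. ['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']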
answered Oct 31 '22 by kHarshit