
How to ignore and initialize Missing key(s) in state_dict

Tags: python, pytorch

My saved state_dict does not contain all the layers that are in my model. How can I ignore the Missing key(s) in state_dict error and initialize the remaining weights?

Asked Jul 23 '20 at 15:07 by Neabfi


2 Answers

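This can be achieved by passing strict=False to load_state_dict. Keys that are absent from the state_dict are then skipped instead of raising an error, so the corresponding parameters keep the values they received when the model was constructed (their default initialization). The call returns a named tuple whose missing_keys and unexpected_keys fields list the keys that were not matched.

model.load_state_dict(state_dict, strict=False)

Documentation: torch.nn.Module.load_state_dict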
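A minimal sketch of this in practice, assuming a hypothetical model `Net` whose `head` layer is not in the saved weights. It loads with strict=False, inspects missing_keys, and then explicitly re-initializes the modules that own the missing parameters (Xavier-uniform weights and zero biases are an arbitrary choice for illustration).

```python
import torch
import torch.nn as nn

# Hypothetical model: the saved state_dict only covers `backbone`.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)  # new layer, absent from the checkpoint

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

# Simulate a saved state_dict that contains only the backbone weights.
old = nn.Linear(4, 8)
saved = {f"backbone.{k}": v for k, v in old.state_dict().items()}

model = Net()
result = model.load_state_dict(saved, strict=False)
print(result.missing_keys)     # keys of `head` that were not in `saved`
print(result.unexpected_keys)  # keys in `saved` that the model does not have

# Missing parameters keep their construction-time values. To initialize them
# explicitly, re-initialize every module that owns a missing key.
missing_modules = {k.rsplit(".", 1)[0] for k in result.missing_keys}
for name, module in model.named_modules():
    if name in missing_modules and isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)
```

Re-initializing whole modules rather than individual tensors keeps the init logic in terms of layers, which is usually how initialization schemes are defined.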

Answered Sep 29 '22 at 06:09 by Neabfi


You can use the following snippet:

self.model.load_state_dict(checkpoint['model'], strict=False)

where checkpoint['model'] is the state_dict of the pre-trained model that you want to load, and self.model is your model (a subclass of nn.Module) whose blocks match those in the saved checkpoint.
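Note that strict=False only tolerates missing and unexpected keys; a key present in both the checkpoint and the model but with a different tensor shape still raises a size-mismatch RuntimeError. A minimal sketch that filters the checkpoint by name and shape before loading, using a hypothetical helper `load_matching` (not a PyTorch API) and two small `nn.Sequential` models standing in for the real ones:

```python
import torch
import torch.nn as nn

def load_matching(model: nn.Module, state_dict: dict) -> list:
    """Load only entries whose key exists in `model` with the same shape.

    Hypothetical helper; returns the sorted list of skipped checkpoint keys.
    """
    own = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in own and own[k].shape == v.shape}
    model.load_state_dict(compatible, strict=False)
    return sorted(set(state_dict) - set(compatible))

# The pre-trained model ends in a 10-way layer; the new one ends in a 2-way layer.
old = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
new = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
checkpoint = {"model": old.state_dict()}

skipped = load_matching(new, checkpoint["model"])
print(skipped)  # ['2.bias', '2.weight']
```

The skipped final layer keeps its freshly initialized weights, while the first layer is copied from the checkpoint.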

Answered Sep 29 '22 at 05:09 by inverted_index