My saved state_dict does not contain all the layers that are in my model. How can I ignore the Missing key(s) in state_dict error and initialize the remaining weights?
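This can be achieved by passing strict=False to load_state_dict. With strict=False, keys that are missing from the state_dict (as well as unexpected extra keys) are ignored instead of raising an error, so the parameters that are not in the checkpoint simply keep the values they were initialized with when the model was constructed.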
model.load_state_dict(state_dict, strict=False)
See the PyTorch documentation for torch.nn.Module.load_state_dict.
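Here is a minimal, self-contained sketch of the strict=False approach. The classes Small (standing in for the model whose weights were saved) and Big (the new model with an extra head layer) are hypothetical. load_state_dict returns a named tuple whose missing_keys and unexpected_keys fields list what was skipped, which tells you exactly which layers still need initializing. Re-initializing the head with Xavier-uniform is only an illustration; you can also keep the default initialization.

```python
import torch
import torch.nn as nn

# Hypothetical "old" model whose weights were saved.
class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)

# Hypothetical "new" model: same encoder plus a head that is not in the checkpoint.
class Big(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

state_dict = Small().state_dict()  # in practice: torch.load(...)
model = Big()

# strict=False skips keys missing from state_dict instead of raising an error.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []

# Optionally (re)initialize the layers that were not loaded.
nn.init.xavier_uniform_(model.head.weight)
nn.init.zeros_(model.head.bias)
```

The encoder weights are copied from the checkpoint, while head keeps its own initialization.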
You can use the following snippet:
self.model.load_state_dict(checkpoint['model'], strict=False)
where checkpoint['model'] is the state_dict of the pre-trained model that you want to load, and self.model is your model (a subclass of nn.Module) whose blocks match those in the saved checkpoint.
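One caveat: strict=False only ignores missing and unexpected keys. If a key exists in both the checkpoint and the model but the tensor shapes differ (for example a classifier with a different number of classes), load_state_dict still raises a RuntimeError. A common workaround, sketched below under assumed names (Pretrained, Finetune, and an in-memory checkpoint dict with a 'model' entry are all hypothetical), is to filter the checkpoint down to keys whose shapes match before loading.

```python
import torch
import torch.nn as nn

# Hypothetical pre-trained model: classifier has 10 outputs.
class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 10)

# Hypothetical target model: classifier shape differs, plus an extra layer.
class Finetune(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 3)
        self.extra = nn.Linear(3, 3)

# Stand-in for checkpoint = torch.load(...) with a 'model' entry.
checkpoint = {'model': Pretrained().state_dict()}
model = Finetune()

# Keep only checkpoint entries that exist in the model with the same shape.
own_state = model.state_dict()
filtered = {n: p for n, p in checkpoint['model'].items()
            if n in own_state and own_state[n].shape == p.shape}

result = model.load_state_dict(filtered, strict=False)
print(sorted(filtered))             # ['backbone.bias', 'backbone.weight']
print(sorted(result.missing_keys))  # layers left at their initial values
```

Everything not in filtered (here classifier and extra) keeps the model's own initialization.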