My saved state_dict does not contain all the layers that are in my model. How can I ignore the Missing key(s) in state_dict error and initialize the remaining weights?
This can be achieved by passing strict=False to load_state_dict. Keys missing from the checkpoint are then skipped instead of raising an error, and the corresponding layers keep their freshly initialized weights:

load_state_dict(state_dict, strict=False)

See the load_state_dict documentation for details.
You can use the following snippet:

self.model.load_state_dict(checkpoint['model'], strict=False)

where checkpoint['model'] is the saved state_dict you want to load, and self.model is your model (inheriting from nn.Module) whose matching layers will receive the checkpoint weights. Note that wrapping the items in dict([...]) is unnecessary, since checkpoint['model'] is already a dict.
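A minimal, self-contained sketch of the idea: the Model class, layer names, and shapes below are hypothetical stand-ins for your own model; the checkpoint is simulated by saving only part of a state_dict. load_state_dict returns a named tuple whose missing_keys field lists the layers that kept their fresh initialization.

```python
import torch
import torch.nn as nn

# Hypothetical model: `backbone` exists in the checkpoint, `head` does not.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)  # new layer, absent from the checkpoint

# Simulate a partial checkpoint that covers only the backbone.
pretrained = Model()
checkpoint = {k: v for k, v in pretrained.state_dict().items()
              if k.startswith("backbone")}

model = Model()
# strict=False skips keys missing from the checkpoint; `head` keeps its
# freshly initialized weights, while `backbone` is overwritten.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

The returned missing_keys / unexpected_keys lists are worth checking after loading, so a typo in a layer name does not silently leave a layer untrained.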