How can I load a partial pretrained pytorch model?

I'm trying to get a PyTorch model running on a sentence classification task. Since I'm working with medical notes, I'm using ClinicalBert (https://github.com/kexinhuang12345/clinicalBERT) and would like to use its pre-trained weights. Unfortunately, the ClinicalBert model only classifies text into a single binary label, while I have 281 binary labels. I'm therefore trying to adapt this code https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb, where the classifier on top of BERT has 281 outputs.

How can I load the pre-trained Bert weights from the ClinicalBert model without loading the classification weights?

Naively loading the pretrained ClinicalBert weights, I get the following error:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

I then tried modifying the from_pretrained function from the pytorch_pretrained_bert package to pop the classifier weights and biases, like this:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    # Drop the classification head so only the BERT body gets loaded
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

With that change I get the following log message:

INFO - modeling_diagnosis - Weights of BertForMultiLabelSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']

In the end, I would like to load the BERT weights from the pretrained ClinicalBert checkpoint and have the top classifier weights initialized randomly.
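For reference, the popping step on a raw checkpoint dictionary can be sketched in isolation. The tiny stand-in checkpoint below is illustrative only (the real ClinicalBert state dict has many more keys), but the pop calls are the same ones used in the modified from_pretrained above:

```python
import torch

# Stand-in "checkpoint": an encoder weight plus a 2-way classifier head.
# A real checkpoint would come from torch.load(weights_path, map_location='cpu').
checkpoint = {
    "bert.weight": torch.randn(768, 768),
    "classifier.weight": torch.randn(2, 768),
    "classifier.bias": torch.randn(2),
}

# Drop the classification head so only the encoder weights remain.
for key in ("classifier.weight", "classifier.bias"):
    checkpoint.pop(key, None)

print(sorted(checkpoint.keys()))  # ['bert.weight']
```

Using pop(key, None) rather than pop(key) avoids a KeyError if a checkpoint happens not to contain the head at all.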

asked Apr 14 '20 by happyrabbit



1 Answer

Removing the keys from the state dict before loading is a good start. Assuming you're using nn.Module.load_state_dict to load the pretrained weights, you'll also need to pass the strict=False argument to avoid errors from unexpected or missing keys. This ignores entries in the state_dict that aren't present in the model (unexpected keys) and, more importantly for you, leaves the missing entries at their default initialization (missing keys). For safety, you can check the method's return value to verify that the weights in question are among the missing keys and that there are no unexpected keys.
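The whole flow can be sketched with a toy module standing in for the BERT classifier (the class name and layer sizes are made up for illustration; only the load_state_dict mechanics carry over):

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Stand-in for a model with a pretrained body and a task-specific head."""
    def __init__(self, num_labels):
        super().__init__()
        self.encoder = nn.Linear(768, 768)        # plays the role of the BERT body
        self.classifier = nn.Linear(768, num_labels)

# Simulate a checkpoint trained with a 2-way head, then drop that head.
state_dict = TinyClassifier(num_labels=2).state_dict()
state_dict.pop('classifier.weight')
state_dict.pop('classifier.bias')

# New model with 281 labels; strict=False loads the matching encoder weights
# and leaves the absent classifier head at its random default initialization.
model = TinyClassifier(num_labels=281)
result = model.load_state_dict(state_dict, strict=False)

# Verify exactly the head is missing and nothing unexpected was in the checkpoint.
assert result.missing_keys == ['classifier.weight', 'classifier.bias']
assert result.unexpected_keys == []
```

load_state_dict returns a named tuple with missing_keys and unexpected_keys fields, so the two asserts at the end are the safety check described above.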

answered Sep 20 '22 by jodag