How to load a BertForSequenceClassification model's weights into a BertForTokenClassification model?

Initially, I fine-tuned a BERT base (uncased) model on a text classification dataset, using the BertForSequenceClassification class:

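from torch.optim import AdamW  # transformers.AdamW is deprecated; PyTorch's AdamW is the recommended replacement
from transformers import BertForSequenceClassification, BertConfig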

# Load BertForSequenceClassification, the pretrained BERT model with a single 
# linear classification layer on top. 
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 2, # The number of output labels--2 for binary classification.
                    # You can increase this for multi-class tasks.   
    output_attentions = False, # Whether the model returns attention weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

Now I want to reuse the weights of this fine-tuned BERT model for named entity recognition (NER), which requires the BertForTokenClassification class. I can't figure out how to load the fine-tuned weights into a new model created with BertForTokenClassification.

Thanks in advance.

Mr. NLP asked May 22 '26 22:05


1 Answer

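You can take the weights of the bert encoder (a BertModel) inside the first model and load them into the bert encoder inside the second. Only the encoder is transferred; the new token-classification head starts randomly initialized and is trained on the NER data. Because BertForTokenClassification builds its encoder without the pooling layer, the load needs strict=False (otherwise PyTorch rejects the extra pooler weights), and config must be a BertConfig whose num_labels matches your NER tag set: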

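from transformers import BertForTokenClassification

new_model = BertForTokenClassification(config=config)  # config: BertConfig with num_labels = number of NER tags
new_model.bert.load_state_dict(model.bert.state_dict(), strict=False)  # skips the pooler weights, which new_model.bert lacks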
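A self-contained sketch of the whole transfer, under a few assumptions: a tiny, randomly initialized BertConfig stands in for the fine-tuned bert-base-uncased model so nothing has to be downloaded, and the 9 NER labels are a hypothetical tag count. It copies the config instead of mutating the original, and prints the keys that PyTorch's load_state_dict reports as missing or unexpected, which makes visible exactly what strict=False skipped.

```python
import copy

import torch
from transformers import BertConfig, BertForSequenceClassification, BertForTokenClassification

# Assumption: a tiny random config stands in for the fine-tuned
# bert-base-uncased classifier so the sketch runs without downloads.
small_config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
model = BertForSequenceClassification(small_config)

# Copy the config so the original model is untouched, then set the number of
# NER tags (9 is hypothetical, e.g. O plus B-/I- tags for four entity types).
config = copy.deepcopy(model.config)
config.num_labels = 9

new_model = BertForTokenClassification(config=config)

# strict=False: the token-classification encoder has no pooler, so the
# pooler weights from the sequence-classification encoder are reported
# as unexpected instead of raising an error.
result = new_model.bert.load_state_dict(model.bert.state_dict(), strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Sanity check: every encoder tensor in the new model now equals the old one.
old_state = model.bert.state_dict()
encoder_matches = all(
    torch.equal(tensor, old_state[name])
    for name, tensor in new_model.bert.state_dict().items()
)
print("encoder weights copied:", encoder_matches)
```

In practice, model would be the fine-tuned classifier from the question, and after the copy only new_model.classifier is untrained. An alternative worth trying is to save the first model with save_pretrained and reload that directory with BertForTokenClassification.from_pretrained("finetuned-bert", num_labels=9, ignore_mismatched_sizes=True), where "finetuned-bert" is a hypothetical path; the flag is needed when the label counts differ, because both classes name their head classifier but give it different output sizes.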
Anastasiia Iurshina answered May 25 '26 07:05


