 

Having 6 labels instead of 2 in Hugging Face BertForSequenceClassification

I was just wondering if it is possible to extend the Hugging Face BertForSequenceClassification model to more than 2 labels. The docs say we can pass positional arguments, but it seems like "labels" is not working. Does anybody have an idea?

Model assignment

import torch as th
import transformers

labels = th.tensor([0, 0, 0, 0, 0, 0], dtype=th.long).unsqueeze(0)
print(labels.shape)
# Passing labels to from_pretrained raises a TypeError (see the traceback below)
modelBERTClass = transformers.BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', 
    labels=labels
    )

l = [module for module in modelBERTClass.modules()]
l

Console Output

torch.Size([1, 6])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-122-fea9a36402a6> in <module>()
      3 modelBERTClass = transformers.BertForSequenceClassification.from_pretrained(
      4     'bert-base-uncased',
----> 5     labels=labels
      6     )
      7 

/usr/local/lib/python3.6/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    653 
    654         # Instantiate model.
--> 655         model = cls(config, *model_args, **model_kwargs)
    656 
    657         if state_dict is None and not from_tf:

TypeError: __init__() got an unexpected keyword argument 'labels'
Alex asked Oct 20 '25 11:10


1 Answer

The labels argument is not a constructor parameter; it is passed to the model's forward call during training so the model can compute a loss. To classify into more than 2 classes, set the output shape of the classification layer with from_pretrained via the num_labels parameter:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
print(model.classifier)

Output:

Linear(in_features=768, out_features=6, bias=True)
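
Once the head has 6 outputs, labels is supplied at the forward pass during fine-tuning, where the model uses it to compute a cross-entropy loss over the 6 classes. A minimal sketch of that step (the example sentence and the class index 3 are made-up values for illustration):

import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)

# Hypothetical training example; labels holds the class index (0-5), not a one-hot vector
inputs = tokenizer("This is an example sentence", return_tensors="pt")
labels = torch.tensor([3])

outputs = model(**inputs, labels=labels)
loss = outputs[0]    # cross-entropy loss over the 6 classes
logits = outputs[1]  # raw scores, shape [1, 6]
loss.backward()      # gradients for a fine-tuning step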
cronoik answered Oct 22 '25 01:10


