BertForSequenceClassification vs. BertForMultipleChoice for sentence multi-class classification

I'm working on a text classification problem (e.g. sentiment analysis), where I need to classify a text string into one of five classes.

I just started using the Huggingface Transformer package and BERT with PyTorch. What I need is a classifier with a softmax layer on top so that I can do 5-way classification. Confusingly, there seem to be two relevant options in the Transformer package: BertForSequenceClassification and BertForMultipleChoice.

Which one should I use for my 5-way classification task? What are the appropriate use cases for them?

The documentation for BertForSequenceClassification doesn't mention softmax at all, although it does mention cross-entropy. I am not sure if this class is only for 2-class classification (i.e. logistic regression).

Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
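For concreteness, here is a minimal sketch of how I understand the 5-way case would look (assuming a recent transformers version; the checkpoint name and label value are placeholders):

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=5  # one logit per class
    )

    inputs = tokenizer("An example sentence to classify.", return_tensors="pt")
    labels = torch.tensor([3])  # class index in [0, 4]

    outputs = model(**inputs, labels=labels)
    loss, logits = outputs[0], outputs[1]  # logits shape: (batch_size, 5)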

The documentation for BertForMultipleChoice mentions softmax, but the way the labels are described, it sounds like this class is for multi-label classification (that is, a separate binary decision for each of several labels).

Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the multiple choice classification loss. Indices should be in [0, ..., num_choices] where num_choices is the size of the second dimension of the input tensors.
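For comparison, here is how I understand BertForMultipleChoice expects its inputs to be shaped (again assuming a recent transformers version; the prompt and choices are made up):

    import torch
    from transformers import BertTokenizer, BertForMultipleChoice

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMultipleChoice.from_pretrained("bert-base-uncased")

    prompt = "She opened the fridge because"
    choices = ["she was hungry.", "the car would not start.", "it was raining."]

    # Encode the same prompt paired with each choice: batch_size=1, num_choices=3.
    enc = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
    inputs = {k: v.unsqueeze(0) for k, v in enc.items()}  # shapes: (1, 3, seq_len)
    labels = torch.tensor([0])  # index of the correct choice

    outputs = model(**inputs, labels=labels)
    loss, logits = outputs[0], outputs[1]  # logits shape: (1, 3)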

Thank you for any help.

asked Mar 10 '20 by stackoverflowuser2010



1 Answer

The answer to this lies in the (admittedly very brief) description of what the tasks are about:

[BertForMultipleChoice] [...], e.g. for RocStories/SWAG tasks.

Looking at the SWAG paper, that task is about choosing the correct continuation from a set of candidate options that varies from sample to sample. This is in contrast to a "classical" classification task, in which the choices (i.e., classes) are fixed across all samples, and that is exactly what BertForSequenceClassification is for.

Both variants in fact support an arbitrary number of classes (for BertForSequenceClassification) or choices (for BertForMultipleChoice): the former via the num_labels parameter in the config, the latter via the size of the second dimension of the input tensors. But since you are dealing with a case of "classical" classification, I suggest using the BertForSequenceClassification model.

Briefly addressing the missing softmax in BertForSequenceClassification: the model returns raw logits because PyTorch's cross-entropy loss applies log-softmax internally (it combines LogSoftmax and NLLLoss). Fusing the two is more numerically stable than computing an explicit softmax in the model and then taking its logarithm in the loss. If you need class probabilities at inference time, just apply a softmax to the logits yourself.
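To make that concrete, a small self-contained sketch (the logit values are arbitrary) showing that cross-entropy on raw logits already contains the softmax, and that you only apply softmax yourself when you want probabilities:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[1.2, -0.3, 0.8, 2.1, -1.0]])  # (batch_size, 5) raw scores
    target = torch.tensor([3])                             # gold class index

    loss = F.cross_entropy(logits, target)   # log-softmax happens inside the loss
    probs = F.softmax(logits, dim=-1)        # only needed for inference/reporting
    pred = probs.argmax(dim=-1)              # predicted class index

    # cross_entropy(logits, t) == nll_loss(log_softmax(logits), t)
    assert torch.allclose(loss, F.nll_loss(F.log_softmax(logits, dim=-1), target))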

answered Sep 21 '22 by dennlinger