Why does the BERT transformer use the [CLS] token for classification instead of an average over all tokens?


I am doing experiments on the BERT architecture and found that most fine-tuning tasks take the final hidden layer as the text representation, which is then passed to other models for the downstream task.

BERT's last layer looks like this:

[Image: BERT's final hidden layer]

Where we take the [CLS] token of each sentence:

[Image: selecting the [CLS] token from the final hidden states]

Image source
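Concretely, this is a minimal sketch of what "taking the [CLS] token" looks like with the Hugging Face transformers library (the checkpoint name and the example sentence are just placeholders):

```python
# Rough sketch: extract the [CLS] hidden state from BERT's last layer.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("An example sentence for classification.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

last_hidden = outputs.last_hidden_state   # (batch_size, seq_len, hidden_size)
cls_vector = last_hidden[:, 0, :]         # [CLS] is always the first token
print(cls_vector.shape)                   # torch.Size([1, 768])
```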

I went through many discussions of this (a huggingface issue, a datascience forum question, a github issue). Most data scientists give this explanation:

BERT is bidirectional, so the [CLS] token is encoded with representative information from all tokens through the multi-layer encoding procedure. The representation of [CLS] is different for different sentences.

My question is: why did the authors ignore the other information (each token's vector), rather than taking the average, max pool, or some other method to make use of all the information, and instead use the [CLS] token for classification?

How does this [CLS] token help compared to the average of all token vectors?

asked Jul 02 '20 by Aaditya Ura

People also ask

Why is the [CLS] token used in BERT?

[CLS] is a special classification token, and the last hidden state of BERT corresponding to this token (h_[CLS]) is used for classification tasks. BERT uses WordPiece embeddings as input for tokens. Along with token embeddings, BERT uses positional embeddings and segment embeddings for each token.

What does CLS mean in BERT?

CLS stands for classification, and it is there to represent the sentence-level class. In short, this token was introduced to make BERT's pooling scheme work.

What are the [CLS] and [SEP] tokens in BERT embeddings?

The [CLS] and [SEP] tokens: in BERT, the decision is that the hidden state of the first token is taken to represent the whole sentence. To achieve this, an additional token has to be added manually to the input sentence. In the original implementation, the token [CLS] is chosen for this purpose.
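As an illustration, a small sketch of the special tokens the BERT tokenizer inserts (assuming the Hugging Face transformers library; the checkpoint and sentences are placeholders):

```python
# Sketch: show where [CLS] and [SEP] end up in a tokenized sentence pair.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("First segment.", "Second segment.")

print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
# ['[CLS]', 'first', 'segment', '.', '[SEP]', 'second', 'segment', '.', '[SEP]']
print(encoded["token_type_ids"])   # segment ids: 0 for the first part, 1 for the second
# [0, 0, 0, 0, 0, 1, 1, 1, 1]
```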



2 Answers

BERT is designed primarily for transfer learning, i.e., fine-tuning on task-specific datasets. If you average the states, every state is averaged with the same weight, including stop words or other tokens that are not relevant for the task. The [CLS] vector gets computed using self-attention (like everything in BERT), so it can only collect the relevant information from the rest of the hidden states. So, in some sense, the [CLS] vector is also an average over token vectors, only more cleverly computed, specifically for the tasks that you fine-tune on.

Also, my experience is that when I keep the weights fixed and do not fine-tune BERT, using the token average yields better results.
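To make that frozen-weights setup concrete, here is a rough sketch (assuming the Hugging Face transformers library and bert-base-uncased as placeholder choices) that computes both the [CLS] vector and an attention-mask-weighted token average:

```python
# Sketch: frozen BERT features -- [CLS] vector vs. average over all token vectors.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").eval()   # weights kept fixed

sentences = ["A short example.", "A somewhat longer example sentence for pooling."]
batch = tokenizer(sentences, padding=True, return_tensors="pt")

with torch.no_grad():
    hidden = model(**batch).last_hidden_state         # (batch, seq_len, hidden_size)

mask = batch["attention_mask"].unsqueeze(-1).float()  # (batch, seq_len, 1), zeros out padding
mean_pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # token average
cls_pooled = hidden[:, 0, :]                                  # [CLS] vector

print(mean_pooled.shape, cls_pooled.shape)            # both torch.Size([2, 768])
```

A downstream classifier can then be trained on either of these fixed representations.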

answered Sep 19 '22 by Jindřich


The use of the [CLS] token to represent the entire sentence comes from the original BERT paper, section 3:

The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.

Your intuition is correct that averaging the vectors of all the tokens may produce superior results. In fact, that is exactly what is mentioned in the Huggingface documentation for BertModel:

Returns

pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)):

Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training.

This output is usually not a good summary of the semantic content of the input, you’re often better with averaging or pooling the sequence of hidden-states for the whole input sequence.

Update: Huggingface removed that statement ("This output is usually not a good summary of the semantic content ...") in v3.1.0. You'll have to ask them why.
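For reference, a rough sketch of how pooler_output compares to averaging the hidden states yourself (bert-base-uncased and the input sentence are just placeholder choices):

```python
# Sketch: pooler_output vs. a plain average of the last hidden states.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").eval()

inputs = tokenizer("An example input sentence.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# [CLS] hidden state passed through the pre-trained Linear + Tanh pooler
pooled = outputs.pooler_output                      # (1, hidden_size)

# the alternative suggested above: average the hidden states yourself
averaged = outputs.last_hidden_state.mean(dim=1)    # (1, hidden_size)
```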

answered Sep 22 '22 by stackoverflowuser2010