
Why take the first hidden state for sequence classification (DistilBertForSequenceClassification) in HuggingFace?

In the classification head of HuggingFace's DistilBertForSequenceClassification, the first hidden state along the sequence-length dimension of the transformer output is taken and used for classification.

hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, num_labels)    <-- classification logits

Is there any benefit to taking the first hidden state rather than the last one, an average over the sequence, or even using a Flatten layer instead?

asked Feb 06 '20 by doe

People also ask

How many layers does a DistilBERT have?

6-layer, 768-hidden, 12-heads, 66M parameters. The multilingual DistilBERT model was distilled from the multilingual BERT checkpoint bert-base-multilingual-cased.
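
If you want to verify these numbers yourself, one minimal sketch (my own illustration, not from the page) is to inspect the default DistilBertConfig in the transformers library:

from transformers import DistilBertConfig

# The default configuration corresponds to distilbert-base-uncased
config = DistilBertConfig()
print(config.n_layers)  # 6   -- Transformer blocks
print(config.dim)       # 768 -- hidden size
print(config.n_heads)   # 12  -- attention heads per layer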

What is the output of DistilBERT?

Flowing through DistilBERT, the output is a vector for each input token; each vector is made up of 768 numbers (floats).
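
As a minimal sketch of that output shape (the model name and example sentence are placeholders of my choosing):

from transformers import DistilBertTokenizer, DistilBertModel
import torch

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One 768-dimensional vector per input token, including [CLS] and [SEP]
print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])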

How does DistilBERT work?

DistilBERT uses a technique called knowledge distillation, which approximates Google's BERT, i.e. the large neural network, with a smaller one. The idea is that once a large neural network has been trained, its full output distributions can be approximated by a smaller network.
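
To make the idea concrete, here is a simplified sketch of a soft-target distillation loss in PyTorch. The temperature value and the KL-only objective are my own simplifications; DistilBERT's actual training objective also includes a masked-language-modelling loss and a cosine embedding loss:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions with a temperature, then push the
    # student's distribution towards the teacher's via KL divergence.
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * temperature ** 2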

What is DistilBERT Pretrained?

DistilBERT was pretrained on the same data as BERT: BookCorpus, a dataset consisting of 11,038 unpublished books, and English Wikipedia (excluding lists, tables, and headers).


1 Answer

Yes, this is directly related to the way that BERT is trained. Specifically, I encourage you to have a look at the original BERT paper, in which the authors introduce the meaning of the [CLS] token:

[CLS] is a special symbol added in front of every input example [...].

Specifically, it is used for classification purposes, and is therefore the first and simplest choice for fine-tuning on classification tasks. What the code fragment you quoted is doing is basically just extracting the hidden state of this [CLS] token.
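
You can see that [CLS] really is the first token in the sequence by inspecting the tokenizer output; a small sketch of my own (the example sentence is arbitrary):

from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
ids = tokenizer("This movie was great!")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
# ['[CLS]', 'this', 'movie', 'was', 'great', '!', '[SEP]']
# hidden_state[:, 0] therefore selects the representation of [CLS]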

Unfortunately, the DistilBERT documentation in HuggingFace's library does not refer to this explicitly; you rather have to check out their BERT documentation, where they also highlight some issues with the [CLS] token, analogous to your concerns:

Alongside MLM, BERT was trained using a next sentence prediction (NSP) objective using the [CLS] token as a sequence approximate. The user may use this token (the first token in a sequence built with special tokens) to get a sequence prediction rather than a token prediction. However, averaging over the sequence may yield better results than using the [CLS] token.
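
If you want to try the averaging alternative the documentation hints at, a common sketch is mean pooling over the non-padding tokens (the helper below is my own illustration, not part of the DistilBERT source):

import torch

def mean_pool(hidden_state, attention_mask):
    # hidden_state: (bs, seq_len, dim), attention_mask: (bs, seq_len)
    mask = attention_mask.unsqueeze(-1).float()   # (bs, seq_len, 1)
    summed = (hidden_state * mask).sum(dim=1)     # (bs, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)      # (bs, 1)
    return summed / counts                        # (bs, dim)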

answered Oct 07 '22 by dennlinger