In the last few layers of HuggingFace's sequence classification model (DistilBertForSequenceClassification), they take the first hidden state along the sequence-length dimension of the transformer output and use it for classification:
hidden_state = distilbert_output[0] # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0] # (bs, dim) <-- first hidden state
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, num_labels)
Is there any benefit to taking the first hidden state over the last one, the average over the sequence, or even using a Flatten layer instead?
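To make the question concrete, here is a rough sketch of what I mean by each alternative, assuming a transformer output hidden_state of shape (bs, seq_len, dim); the tensor and variable names are just illustrative:

import torch

# Toy tensor standing in for the transformer output; shapes are illustrative.
bs, seq_len, dim = 4, 16, 768
hidden_state = torch.randn(bs, seq_len, dim)

first = hidden_state[:, 0]                # (bs, dim)         -- first position ([CLS])
last = hidden_state[:, -1]                # (bs, dim)         -- last position (often [SEP] or padding)
mean = hidden_state.mean(dim=1)           # (bs, dim)         -- average over the sequence
flat = hidden_state.flatten(start_dim=1)  # (bs, seq_len*dim) -- Flatten layer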
6-layer, 768-hidden, 12-heads, 134M parameters. The multilingual DistilBERT model distilled from the Multilingual BERT model bert-base-multilingual-cased checkpoint.
Flowing through DistilBERT, the output is a vector for each input token; each vector is made up of 768 numbers (floats).
DistilBERT uses a technique called knowledge distillation, which approximates Google's BERT, i.e. the large neural network, with a smaller one. The idea is that once a large neural network has been trained, its full output distributions can be approximated using a smaller network.
DistilBERT is pretrained on the same data as BERT, which is BookCorpus, a dataset consisting of 11,038 unpublished books, and English Wikipedia (excluding lists, tables and headers).
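As an illustration of "approximating the full output distributions", here is a minimal sketch of a soft-target distillation loss (temperature-softened KL divergence between teacher and student logits). This is only the distillation part; the actual DistilBERT training objective additionally combines it with the masked language modelling loss and a cosine embedding loss.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # Soft-target loss: KL divergence between temperature-softened teacher
    # and student distributions (only the distillation term of the objective).
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2

# Example with vocabulary-sized logits for a batch of masked positions.
student_logits = torch.randn(8, 30522)
teacher_logits = torch.randn(8, 30522)
loss = distillation_loss(student_logits, teacher_logits)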
Yes, this is directly related to the way that BERT is trained. Specifically, I encourage you to have a look at the original BERT paper, in which the authors introduce the meaning of the [CLS] token:

[CLS] is a special symbol added in front of every input example [...].

Specifically, it is used for classification purposes, and is therefore the first and simplest choice for any fine-tuning on classification tasks. What your relevant code fragment is doing is basically just extracting this [CLS] token.
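As a quick sanity check, the sketch below (using the standard English distilbert-base-uncased checkpoint; the multilingual model behaves the same way) shows that position 0 of the tokenized input is the [CLS] token, so hidden_state[:, 0] is exactly its representation:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Is the first hidden state the [CLS] token?", return_tensors="pt")
first_id = inputs["input_ids"][0, 0].item()
print(tokenizer.convert_ids_to_tokens(first_id))  # '[CLS]' -- position 0 is the [CLS] token

with torch.no_grad():
    outputs = model(**inputs)

# The pooled_output in your snippet is exactly this slice: the hidden state at position 0.
cls_embedding = outputs.last_hidden_state[:, 0]  # (bs, dim)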
Unfortunately, the DistilBERT documentation of Huggingface's library does not explicitly refer to this, but you rather have to check out their BERT documentation, where they also highlight some issues with the [CLS] token, analogous to your concerns:

Alongside MLM, BERT was trained using a next sentence prediction (NSP) objective using the [CLS] token as a sequence approximate. The user may use this token (the first token in a sequence built with special tokens) to get a sequence prediction rather than a token prediction. However, averaging over the sequence may yield better results than using the [CLS] token.
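If you want to try the averaging the documentation hints at, one common approach (a sketch, not code from the HuggingFace classification head; masked_mean_pooling is a hypothetical helper name) is to average the token embeddings while masking out padding positions via the attention mask:

import torch

def masked_mean_pooling(hidden_state, attention_mask):
    # hidden_state:   (bs, seq_len, dim) transformer output
    # attention_mask: (bs, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).type_as(hidden_state)  # (bs, seq_len, 1)
    summed = (hidden_state * mask).sum(dim=1)                  # (bs, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)                   # (bs, 1)
    return summed / counts                                     # (bs, dim)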