 

How can I get all outputs of the last transformer encoder in a pretrained BERT model, and not just the CLS token output?

I'm using PyTorch, and this is the model from Hugging Face Transformers:

from transformers import BertTokenizerFast, BertForSequenceClassification
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                     num_labels=int(data['class'].nunique()),
                                                     output_attentions=False,
                                                     output_hidden_states=False)

In the forward function I'm building, I call x1, x2 = self.bert(sent_id, attention_mask=mask). As far as I know, x2 is the CLS output, but I don't think I fully understand the model's outputs. What I want is the output of all 12 transformer encoder layers. How can I do that in PyTorch?

Alaa Grable asked Nov 17 '25


1 Answer

Ideally, if you want to look into the outputs of all the layers, you should use BertModel and not BertForSequenceClassification, because BertForSequenceClassification inherits from BertModel and just adds a linear classification layer on top of the BERT model.

from transformers import BertModel
my_bert_model = BertModel.from_pretrained("bert-base-uncased")

### Add your code to map the model to device, data to device, and obtain input_ids and mask

sequence_output, pooled_output = my_bert_model(ids, attention_mask=mask, return_dict=False)

sequence_output has shape (batch_size, sequence_length, 768): it contains the last-layer output of the BERT model for every token in the sequence. (Recent versions of transformers return a model-output object by default, so pass return_dict=False to unpack the result as a tuple.)

In order to obtain the outputs of all the transformer encoder layers, you can use the following:

my_bert_model = BertModel.from_pretrained("bert-base-uncased")
sequence_output, pooled_output, all_layer_output = my_bert_model(ids, attention_mask=mask, output_hidden_states=True, return_dict=False)

all_layer_output is a tuple containing the output of the embedding layer plus the outputs of all 12 encoder layers, i.e. 13 tensors in total. Each element of the tuple has shape (batch_size, sequence_length, 768).

Hence, to get the sequence of outputs at layer 5, use all_layer_output[5]; all_layer_output[0] holds the outputs of the embedding layer.
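As a quick sanity check of this indexing, here is a minimal sketch using a randomly initialized toy BertConfig (the sizes below are made up so that nothing needs to be downloaded; the real bert-base-uncased has hidden_size=768 and 12 layers):

```python
import torch
from transformers import BertConfig, BertModel

# Toy config for illustration only -- not the real bert-base-uncased sizes.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=4,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config)
model.eval()

ids = torch.randint(0, 100, (2, 8))   # (batch_size=2, sequence_length=8)
mask = torch.ones_like(ids)

with torch.no_grad():
    out = model(ids, attention_mask=mask, output_hidden_states=True)

# out.hidden_states is a tuple: embedding output + one tensor per encoder layer.
print(len(out.hidden_states))        # 5  (num_hidden_layers + 1)
print(out.hidden_states[0].shape)    # torch.Size([2, 8, 32]) -- embedding layer
print(out.hidden_states[-1].shape)   # torch.Size([2, 8, 32]) -- last encoder layer
```

Note that out.hidden_states[-1] is exactly out.last_hidden_state, which is why indexing from 0 (embeddings) up through the last layer works as described above.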

Ashwin Geet D'Sa answered Nov 19 '25


