
How to efficiently mean-pool BERT embeddings while excluding padding?

Consider a batch of sentences with different lengths.

When using the BertTokenizer, I apply padding so that all the sequences have the same length and we end up with a nice tensor of shape (bs, max_seq_len).

After applying the BertModel, I get a last hidden state of shape (bs, max_seq_len, hidden_sz).

My goal is to get the mean-pooled sentence embedding for each sentence (resulting in something with shape (bs, hidden_sz)), but excluding the embeddings for the PAD tokens when taking the mean.

Is there a way to do this efficiently without looping over each sequence in the batch?

Thanks!

Kevin asked Mar 09 '26 18:03
1 Answer

You can overwrite the PAD positions of the hidden state with NaN and then use torch.nanmean, which ignores NaNs when averaging. Afterwards, change the NaNs back to a value less likely to cause gradient issues down the line (e.g. 0):

mean_pooled = torch.nanmean(hidden_state, dim=1)        # NaN (PAD) positions are ignored
hidden_state = torch.nan_to_num(hidden_state, nan=0.0)  # replace NaNs so they don't leak into later ops
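An end-to-end sketch of this NaN route, using a made-up batch (the shapes and the PAD position are assumptions for illustration, not from the question):

```python
import torch

bs, max_seq_len, hidden_sz = 2, 4, 3
torch.manual_seed(0)
hidden_state = torch.randn(bs, max_seq_len, hidden_sz)

# Suppose the second sequence ends in one PAD token: overwrite its
# hidden states with NaN before pooling.
hidden_state[1, 3, :] = float("nan")

# Mean over the sequence dimension, skipping the NaN positions.
mean_pooled = torch.nanmean(hidden_state, dim=1)  # (bs, hidden_sz)

# Replace the NaNs afterwards so they cannot poison later computations.
hidden_state = torch.nan_to_num(hidden_state, nan=0.0)

# Sanity check: row 1's pooled vector equals the mean of its 3 real tokens.
expected_row1 = hidden_state[1, :3, :].mean(dim=0)
```

Note that torch.nanmean requires a reasonably recent PyTorch (it is not available in very old releases).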

Alternatively, take the sum over the sequence dimension and divide by the number of non-zero elements (this assumes the PAD positions have been zeroed out):

mean_pooled = torch.sum(hidden_state, dim=1) / (hidden_state != 0).sum(dim=1)
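For comparison, the most common route is to build the divisor from the tokenizer's attention_mask (1 for real tokens, 0 for PAD) instead of inspecting the activations themselves. A sketch with dummy tensors (the mask values and shapes are assumptions for illustration):

```python
import torch

bs, max_seq_len, hidden_sz = 2, 4, 3
torch.manual_seed(0)
hidden_state = torch.randn(bs, max_seq_len, hidden_sz)

# attention_mask as the tokenizer would return it: 1 = real token, 0 = PAD.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])

# Broadcast the mask over the hidden dimension, zero out PAD positions,
# then divide each row's sum by that row's real-token count.
mask = attention_mask.unsqueeze(-1).float()   # (bs, max_seq_len, 1)
summed = (hidden_state * mask).sum(dim=1)     # (bs, hidden_sz)
counts = mask.sum(dim=1).clamp(min=1e-9)      # (bs, 1); clamp guards divide-by-zero
mean_pooled = summed / counts                 # (bs, hidden_sz)
```

This avoids both NaNs and the assumption that padded hidden states are exactly zero, and it vectorizes over the whole batch with no Python loop.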
DerekG answered Mar 11 '26 09:03

