
How to extract document embeddings from HuggingFace Longformer

I'm looking to do something similar to this:

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden state is the first element of the output tuple

(from this thread), but using the Longformer.

The documentation example seems to do something similar, but it's confusing, especially regarding how to set the attention mask. I assume I'd want to set global attention on the [CLS] token, but the example seems to set it on arbitrary token indices:

>>> import torch
>>> from transformers import LongformerModel, LongformerTokenizer

>>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096', return_dict=True)
>>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

>>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000)  # long input document
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1

>>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
>>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
>>> attention_mask[:, [1, 4, 21,]] = 2  # Set global attention based on the task. For example,
...                                     # classification: the <s> token
...                                     # QA: question tokens
...                                     # LM: potentially on the beginning of sentences and paragraphs
>>> outputs = model(input_ids, attention_mask=attention_mask)
>>> sequence_output = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output

(from here)

asked Nov 06 '22 by Maxim Khesin


1 Answer

You don't need to mess with those values (unless you want to optimize the way Longformer attends to different tokens). In the example you've listed above, it enforces global attention on just the 1st, 4th, and 21st tokens. Those indices are arbitrary placeholders; in practice you might want to globally attend to a certain type of token, such as the question tokens in a sequence (e.g. <question tokens> + <answer tokens>, where only the first part gets global attention).
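For document embeddings or classification, that typically means global attention on just the first token (Longformer's <s>, which plays the role of BERT's [CLS]). Here is a minimal sketch of the documentation example with only that change, using the same 0/1/2 attention-mask convention as above (newer transformers versions also accept a separate global_attention_mask argument):

import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained('allenai/longformer-base-4096', return_dict=True)
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

text = ' '.join(['Hello world! '] * 1000)  # long input document
input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)  # batch of size 1

# Local attention everywhere...
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
# ...and global attention only on the first token (<s>), instead of the
# arbitrary indices [1, 4, 21] from the documentation example.
attention_mask[:, 0] = 2

outputs = model(input_ids, attention_mask=attention_mask)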

If you're looking for just embeddings, you can follow what's been discussed here: The last layers of longformer for document embeddings.
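To sketch the kind of thing that thread discusses (this is one common recipe, not necessarily the exact one proposed there): take the <s> token's final hidden state as the document embedding, or mean-pool the last hidden states over real tokens. Continuing from the sketch above, where outputs and attention_mask are already defined:

last_hidden = outputs.last_hidden_state             # (batch, seq_len, hidden_size)

# Option 1: the <s> token's hidden state as the document embedding
cls_embedding = last_hidden[:, 0, :]                # (batch, hidden_size)

# Option 2: mean-pool over all real (non-padding) tokens
token_mask = (attention_mask > 0).unsqueeze(-1)     # (batch, seq_len, 1)
mean_embedding = (last_hidden * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1)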

answered Nov 30 '22 by Ramesh Arvind