
BertModel transformers outputs string instead of tensor

I'm following this tutorial that builds a sentiment analysis classifier with BERT using the huggingface transformers library, and I'm seeing very odd behavior. When I try the BERT model on a sample text, I get strings instead of the hidden-state tensors. This is the code I'm using:

import transformers
from transformers import BertModel, BertTokenizer

print(transformers.__version__)

PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
PATH_OF_CACHE = "/home/mwon/data-mwon/paperChega/src_classificador/data/hugingface"

tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME,cache_dir = PATH_OF_CACHE)

sample_txt = 'When was I last outside? I am stuck at home for 2 weeks.'

encoding_sample = tokenizer.encode_plus(
  sample_txt,
  max_length=32,
  add_special_tokens=True, # Add '[CLS]' and '[SEP]'
  return_token_type_ids=False,
  padding=True,
  truncation = True,
  return_attention_mask=True,
  return_tensors='pt',  # Return PyTorch tensors
)

bert_model = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME,cache_dir = PATH_OF_CACHE)


last_hidden_state, pooled_output = bert_model(
  encoding_sample['input_ids'],
  encoding_sample['attention_mask']
)

print([last_hidden_state,pooled_output])

that outputs:

4.0.0
['last_hidden_state', 'pooler_output']
 
Miguel asked Dec 03 '20


2 Answers

I faced the same issue while learning how to implement BERT. I noticed that using

last_hidden_state, pooled_output = bert_model(encoding_sample['input_ids'], encoding_sample['attention_mask'])

is the issue. Use:

outputs = bert_model(encoding_sample['input_ids'], encoding_sample['attention_mask'])

and extract the last hidden state using

outputs[0]

You can refer to the BertModel documentation, which tells you exactly what the model returns.
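
Applied to the question's code, the fix looks roughly like this (a minimal sketch; it assumes the encoding_sample and bert_model from the question):

# Keep the whole output instead of unpacking it into two variables.
outputs = bert_model(
    encoding_sample['input_ids'],
    encoding_sample['attention_mask']
)

last_hidden_state = outputs[0]  # (batch_size, sequence_length, hidden_size)
pooled_output = outputs[1]      # (batch_size, hidden_size)

print(last_hidden_state.shape, pooled_output.shape)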

Aakash Bhatia answered Oct 29 '22


While the answer from Aakash provides a solution to the problem, it does not explain the issue. Since one of the 3.x releases of the transformers library, the models no longer return plain tuples but dedicated output objects:

o = bert_model(
    encoding_sample['input_ids'],
    encoding_sample['attention_mask']
)
print(type(o))
print(o.keys())

Output:

<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>
odict_keys(['last_hidden_state', 'pooler_output'])
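
Because the result is an output object rather than a tuple, its fields can also be read by name, which is usually the clearest option (a short sketch reusing the o from above):

# All three access styles return the same tensors.
last_hidden_state = o.last_hidden_state   # same as o['last_hidden_state'] or o[0]
pooled_output = o.pooler_output           # same as o['pooler_output'] or o[1]

print(last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(pooled_output.shape)      # (batch_size, hidden_size)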

You can return to the previous behavior by adding return_dict=False to get a tuple:

o = bert_model(
   encoding_sample['input_ids'],
   encoding_sample['attention_mask'],
   return_dict=False
)

print(type(o))

Output:

<class 'tuple'>
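
With return_dict=False, the tuple unpacking from the question's code works again (a sketch under the same setup as the question):

# The two tuple elements are the last hidden state and the pooled output.
last_hidden_state, pooled_output = bert_model(
    encoding_sample['input_ids'],
    encoding_sample['attention_mask'],
    return_dict=False
)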

I do not recommend that, because with a plain tuple it is ambiguous which element is which without turning to the documentation, as shown in the example below:

o = bert_model(encoding_sample['input_ids'], encoding_sample['attention_mask'], return_dict=False, output_attentions=True, output_hidden_states=True)
print('I am a tuple with {} elements. You do not know what each element represents without checking the documentation'.format(len(o)))

o = bert_model(encoding_sample['input_ids'], encoding_sample['attention_mask'], output_attentions=True, output_hidden_states=True)
print('I am a cool object and you can access my elements with o.last_hidden_state, o["last_hidden_state"] or even o[0]. My keys are: {}'.format(o.keys()))

Output:

I am a tuple with 4 elements. You do not know what each element represents without checking the documentation
I am a cool object and you can access my elements with o.last_hidden_state, o["last_hidden_state"] or even o[0]. My keys are: odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states', 'attentions'])
cronoik answered Oct 29 '22