
How to make a Trainer pad inputs in a batch with huggingface-transformers?

I'm trying to train a model using a Trainer. According to the documentation (https://huggingface.co/transformers/master/main_classes/trainer.html#transformers.Trainer), I can specify a tokenizer:

tokenizer (PreTrainedTokenizerBase, optional) – The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.

So padding should be handled automatically, but when trying to run it I get this error:

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length.

The tokenizer is created this way:

tokenizer = BertTokenizerFast.from_pretrained(pretrained_model)

And the Trainer like this:

trainer = Trainer(
    tokenizer=tokenizer,
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=dev,
    compute_metrics=compute_metrics
)

I've tried putting the padding and truncation parameters in the tokenizer, in the Trainer, and in the training_args. Nothing works. Any ideas?

François MENTEC asked Sep 24 '20 13:09



1 Answer

Look at the columns your tokenized dataset contains. Any leftover column (such as the raw text) gets passed along when batching, so you might want to limit the dataset to only the required columns.

For example:

def preprocess_function(examples):
    # Tokenize a single sentence, or a sentence pair when sentence2_key is set.
    # sentence1_key, sentence2_key and tokenizer are defined elsewhere.
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True, padding=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True, padding=True)


encoded_dataset = dataset.map(preprocess_function, batched=True, load_from_cache_file=False)


# The key step: keep only the columns the model actually needs.

columns_to_return = ['input_ids', 'label', 'attention_mask']
encoded_dataset.set_format(type='torch', columns=columns_to_return)
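
To see why the extra columns matter, here is a rough pure-Python sketch of what padding at batch time amounts to. It is only an illustration, not the library's code: `pad_batch`, `TOKEN_COLUMNS` and the toy token ids are made up for this example, and it assumes the label column is called 'label' as above. As far as I understand, when you pass a tokenizer the Trainer pads each batch for you in a similar spirit, and a leftover column such as the raw sentence is what cannot be turned into a tensor.

```python
# Illustrative sketch only: pad_batch is a hypothetical helper, not a transformers API.
from typing import Dict, List

TOKEN_COLUMNS = ("input_ids", "attention_mask")


def pad_batch(features: List[Dict], pad_token_id: int = 0) -> Dict[str, List]:
    """Pad token columns to the longest example in the batch."""
    all_columns = set().union(*(f.keys() for f in features))
    unexpected = all_columns - set(TOKEN_COLUMNS) - {"label"}
    if unexpected:
        # Leftover columns (e.g. raw text) cannot be padded into a rectangular batch.
        raise ValueError(f"Columns that cannot be padded: {sorted(unexpected)}")

    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List] = {"input_ids": [], "attention_mask": [], "label": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # Padding positions get a 0 in the attention mask so they are ignored.
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        batch["label"].append(f["label"])
    return batch


# Two examples of different lengths (toy token ids).
features = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "label": 0},
    {"input_ids": [101, 7592, 2088, 999, 102], "attention_mask": [1, 1, 1, 1, 1], "label": 1},
]
batch = pad_batch(features)
```

After this, every row in batch["input_ids"] has the same length, which is what a tensor needs. If one of the examples still carried a 'sentence' string column, the helper would raise instead, which is the situation set_format with an explicit column list avoids.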

Hope it helps.

Animesh Seemendra answered Oct 05 '22 05:10
