
Text generation using huggingface's distilbert models

I've been struggling with Hugging Face's DistilBERT model for some time now, since the documentation seems very unclear and their examples (e.g. https://github.com/huggingface/transformers/blob/master/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb and https://github.com/huggingface/transformers/tree/master/examples/distillation) are extremely dense, and what they showcase doesn't seem well documented.

I'm wondering if anyone here has any experience with it and knows of a good code example for basic in-Python usage of their models. Namely:

  • How to properly decode the output of the model into actual text (no matter how I change its shape, the tokenizer seems willing to decode it, but it always yields some sequence of [UNK] tokens)

  • How to actually use their schedulers and optimizers to train a model for a simple text-to-text task.

George asked Dec 08 '19 22:12


1 Answer

To decode the output, you can do

        prediction_as_text = tokenizer.decode(output_ids, skip_special_tokens=True)

output_ids contains the generated token ids. It can also be a batch (one sequence of output ids per row); in that case, use tokenizer.batch_decode, and prediction_as_text will be a list containing the text of every row. skip_special_tokens=True filters out the special tokens used during training, such as the end-of-sentence and start-of-sentence tokens. These special tokens vary from model to model, of course, but almost every model has such special tokens that are used during training and inference.
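
As a minimal sketch of the round trip (assuming the standard distilbert-base-uncased checkpoint; the encoded sentence just stands in for whatever ids your model generates):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # the encoded sentence stands in for ids a model would generate
    output_ids = tokenizer.encode("Hello, how are you?")

    # single sequence of ids -> one string
    print(tokenizer.decode(output_ids, skip_special_tokens=True))

    # batch of sequences -> list of strings, one per row
    batch_ids = [output_ids, tokenizer.encode("Another example sentence.")]
    print(tokenizer.batch_decode(batch_ids, skip_special_tokens=True))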

There is no easy way to get rid of unknown tokens ([UNK]). The models have a limited vocabulary. If a model encounters a subword that is not in its vocabulary, it is replaced by a special unknown token, and the model is trained with these tokens, so it also learns to generate [UNK]. There are various ways to deal with it, such as replacing it with the second most probable token, or using beam search and taking the most probable sentence that does not contain any unknown tokens. However, if you really want to get rid of these, you should rather use a model whose tokenizer uses byte-level Byte Pair Encoding (BPE), which avoids unknown words entirely. As you can read in this link, BERT and DistilBERT use subword (WordPiece) tokenization and have such a limitation: https://huggingface.co/transformers/tokenizer_summary.html
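
A quick sketch to see the difference (assuming the standard distilbert-base-uncased and gpt2 checkpoints) and to check a generation for unknown tokens:

    from transformers import AutoTokenizer

    wordpiece_tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    bpe_tok = AutoTokenizer.from_pretrained("gpt2")  # byte-level BPE

    text = "résumé 🙂"  # characters that may be missing from a WordPiece vocab

    print(wordpiece_tok.tokenize(text))  # may contain '[UNK]' for unknown characters
    print(bpe_tok.tokenize(text))        # always falls back to known byte-level pieces

    # check generated ids for the unknown token
    output_ids = wordpiece_tok.encode(text)
    if wordpiece_tok.unk_token_id in output_ids:
        print("output contains", wordpiece_tok.unk_token)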

To use the schedulers and optimizers, you should use the Trainer and TrainingArguments classes (here, their sequence-to-sequence variants Seq2SeqTrainer and Seq2SeqTrainingArguments). Below is an example from one of my own projects.

import torch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=model_directory,
    num_train_epochs=args.epochs,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    warmup_steps=500,
    weight_decay=args.weight_decay,
    logging_dir=model_directory,
    logging_steps=100,
    do_eval=True,
    evaluation_strategy='epoch',
    learning_rate=args.learning_rate,
    load_best_model_at_end=True,  # at the end of training, load the checkpoint that is best wrt metric_for_best_model
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    lr_scheduler_type='linear',
    save_total_limit=args.epochs if args.save_total_limit == -1 else args.save_total_limit,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    optimizers=(torch.optim.Adam(params=model.parameters(),
                                 lr=args.learning_rate), None),  # (optimizer, lr_scheduler); None lets the Trainer build the scheduler from the args
    tokenizer=tokenizer,
)
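
With the arguments and the trainer set up (and assuming train_dataset and val_dataset are already tokenized datasets), training then runs with:

    trainer.train()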

For other scheduler types, see this link: https://huggingface.co/transformers/main_classes/optimizer_schedules.html
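
If you would rather build the scheduler yourself instead of letting the Trainer create it from lr_scheduler_type, a sketch along these lines should work (num_training_steps is a placeholder; normally it is the number of batches per epoch times the number of epochs):

    import torch
    from transformers import get_linear_schedule_with_warmup

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    num_training_steps = 10000  # placeholder: len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=num_training_steps,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        optimizers=(optimizer, scheduler),  # pass optimizer and scheduler explicitly
        tokenizer=tokenizer,
    )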

Berkay Berabi answered Dec 11 '22 02:12