 

Which model (GPT-2, BERT, XLNet, etc.) would you use for a text classification task? Why?

I'm trying to train a model for a sentence classification task. The input is a sentence (a vector of integers) and the output is a label (0 or 1). I've seen some articles here and there about using BERT and GPT-2 for text classification tasks, but I'm not sure which one I should pick to start with. Which of these recent NLP models — the original Transformer, BERT, GPT-2, XLNet — would you start with, and why? I'd rather implement it in TensorFlow, but I'm flexible enough to go with PyTorch too. Thanks!

asked Sep 08 '19 by khemedi


3 Answers

I agree with Max's answer, but if the constraint is to use a state-of-the-art large pre-trained model, there is a really easy way to do this: the HuggingFace library called pytorch-transformers (since renamed to transformers). Whether you choose BERT, XLNet, or anything else, the models are easy to swap out. There is a detailed tutorial on using that library for text classification.

EDIT: I just came across this repo, pytorch-transformers-classification (Apache 2.0 license), which is a tool for doing exactly what you want.
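To illustrate the "easy to swap out" point, here is a minimal sketch using the current transformers API. It builds a tiny randomly initialized BERT from a config so it runs offline; in real use you would instead call `BertForSequenceClassification.from_pretrained("bert-base-uncased")` to get pre-trained weights. The config sizes here are made up for speed.

```python
# Sketch: HuggingFace sequence classification, offline-friendly.
# A tiny random config stands in for a downloaded checkpoint.
import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
    num_labels=2,                      # binary classification (0 or 1)
)
model = BertForSequenceClassification(config)

input_ids = torch.randint(0, 100, (4, 16))   # batch of 4 toy "sentences"
labels = torch.tensor([0, 1, 0, 1])
outputs = model(input_ids=input_ids, labels=labels)
print(outputs.logits.shape)                  # one score per label: (4, 2)
```

Swapping models is essentially a one-line change — e.g. `XLNetConfig` + `XLNetForSequenceClassification` (or `AutoModelForSequenceClassification` with a different checkpoint name) in place of the BERT classes.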

answered Oct 17 '22 by Sam H.


It depends heavily on your dataset, and it is part of the data scientist's job to find out which model is most suitable for a particular task in terms of the selected performance metric, training cost, model complexity, etc.

When you work on the problem you will probably test all of the above models and compare them. Which one should you try first? Andrew Ng in "Machine Learning Yearning" suggests starting with a simple model so you can quickly iterate and test your ideas, data preprocessing pipeline, etc.

Don’t start off trying to design and build the perfect system. Instead, build and train a basic system quickly — perhaps in just a few days.

According to this suggestion, you can start with a simpler model such as ULMFiT as a baseline, verify your ideas and then move on to more complex models and see how they can improve your results.
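In that spirit, an even simpler baseline than ULMFiT can be trained in seconds. The sketch below uses scikit-learn's TF-IDF features plus logistic regression on a few made-up sentences (labels and texts are purely illustrative) — enough to exercise your preprocessing pipeline before reaching for a Transformer.

```python
# A quick-iteration baseline: TF-IDF features + logistic regression.
# The toy sentences are invented for illustration only.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

sentences = [
    "great movie, loved it", "absolutely wonderful film",
    "terrible plot, boring", "awful acting, waste of time",
]
labels = [1, 1, 0, 0]

baseline = make_pipeline(TfidfVectorizer(), LogisticRegression())
baseline.fit(sentences, labels)
print(baseline.predict(["boring and terrible"]))  # likely predicts 0
```

Once this baseline and its evaluation metric are in place, any score a large pre-trained model achieves can be judged against it.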

Note that modern NLP models contain a large number of parameters, and it is difficult to train them from scratch without a large dataset. That's why you may want to use transfer learning: you can download a pre-trained model, use it as a basis, and fine-tune it on your task-specific dataset to achieve better performance and reduce training time.
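One common way to cut fine-tuning cost, sketched below, is to freeze the pre-trained encoder and train only the small classification head. A tiny randomly initialized BERT stands in for a downloaded checkpoint so the example runs offline; with real data you would start from `.from_pretrained(...)` and real token IDs.

```python
# Sketch: fine-tune only the classification head of a (stand-in) BERT.
import torch
from transformers import BertConfig, BertForSequenceClassification

model = BertForSequenceClassification(
    BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
               num_attention_heads=2, intermediate_size=64, num_labels=2))

for p in model.bert.parameters():   # freeze the pre-trained encoder
    p.requires_grad = False

head_params = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.Adam(head_params, lr=1e-3)

x = torch.randint(0, 100, (8, 12))  # dummy token IDs
y = torch.randint(0, 2, (8,))       # dummy labels
for _ in range(3):                  # a few toy optimization steps
    out = model(input_ids=x, labels=y)
    optim.zero_grad()
    out.loss.backward()
    optim.step()
```

Only the classifier's weight and bias are updated here, which is far cheaper than back-propagating through the whole encoder; unfreezing some top layers later is a common middle ground.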

answered Oct 17 '22 by Max


Well, as others have mentioned, it depends on the dataset: multiple models should be tried and the best one chosen.

However, sharing my experience: XLNet has beaten all the other models so far by a good margin. Hence, if learning is not the objective, I would simply start with XLNet, then try a few more down the line and conclude. It just saves exploration time.

The repo below is excellent for doing all of this quickly. Kudos to them.

https://github.com/microsoft/nlp-recipes

It uses Hugging Face transformers and makes them dead simple. 😃

answered Oct 17 '22 by Narahari B M