How to use Transformers for text classification?

I have two questions about how to use the TensorFlow implementation of the Transformer for text classification.

  • First, it seems people mostly use only the encoder layer to do the text classification task. However, the encoder layer generates one prediction for each input word. Based on my understanding of transformers, the input to the encoder at each step is one word from the input sentence. Then the attention weights and the output are calculated using the current input word, and we can repeat this process for all of the words in the input sentence. As a result, we end up with a pair of (attention weights, output) for each word in the input sentence. Is that correct? If so, how would you use these pairs to perform text classification?
  • Second, based on the TensorFlow implementation of the transformer here, they embed the whole input sentence into one vector and feed a batch of these vectors to the Transformer. However, I expected the input to be a batch of words instead of sentences, based on what I've learned from The Illustrated Transformer.

Thank you!

asked Sep 26 '19 by khemedi

People also ask

Can transformers be used for text classification?

The Transformer model performs quite well on text classification tasks, achieving the desired results on most predictions.

Which transformer model is best for text classification?

For tasks in which the text classes are relatively few, the best-performing text classification systems use pretrained Transformer models such as BERT, XLNet, and RoBERTa. But Transformer-based models scale quadratically with the input sequence length and linearly with the number of classes.

Can transformer be used for classification?

Transformers can be used for classification tasks. I found a good tutorial where they used a BERT Transformer for the encoding and a Convolutional Neural Network for sentiment analysis.


2 Answers

There are two approaches you can take:

  1. Just average the states you get from the encoder;
  2. Prepend a special token [CLS] (or whatever you like to call it) and use the hidden state for the special token as input to your classifier.

The second approach is used by BERT. When pre-training, the hidden state corresponding to this special token is used for predicting whether two sentences are consecutive. In the downstream tasks, it is also used for sentence classification. However, my experience is that sometimes averaging the hidden states gives a better result.
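As a rough illustration of the two pooling strategies (a minimal sketch: `encoder_outputs` is a placeholder for the encoder's final hidden states, and the shapes are made up):

```python
import tensorflow as tf

# Placeholder for the encoder's final hidden states,
# shape (batch, seq_len, d_model).
batch, seq_len, d_model = 32, 128, 512
encoder_outputs = tf.random.normal((batch, seq_len, d_model))

# Approach 1: average the per-token states over the sequence dimension.
avg_pooled = tf.reduce_mean(encoder_outputs, axis=1)  # (batch, d_model)

# Approach 2: prepend a [CLS] token to the input before encoding,
# then take the hidden state at position 0 as the sentence vector.
cls_pooled = encoder_outputs[:, 0, :]                 # (batch, d_model)

# Either vector can feed a classifier head.
num_classes = 2
classifier = tf.keras.layers.Dense(num_classes)
logits = classifier(avg_pooled)                       # (batch, num_classes)
```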

Instead of training a Transformer model from scratch, it is probably more convenient to use (and possibly fine-tune) a pre-trained model (BERT, XLNet, DistilBERT, ...) from the transformers package. It has pre-trained models ready to use in PyTorch and TensorFlow 2.0.
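For example, a minimal fine-tuning sketch with the transformers package in TensorFlow 2 (the checkpoint name, label count, and toy data are placeholders, not recommendations):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Load a pre-trained checkpoint with a fresh classification head.
model_name = "bert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)

texts = ["a great movie", "a terrible movie"]  # toy data
labels = tf.constant([1, 0])

# Tokenize whole sentences; the tokenizer handles padding and truncation.
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="tf")

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(dict(inputs), labels, epochs=1)
```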

answered Nov 03 '22 by Jindřich

  1. The Transformer is designed to take the whole input sentence at once. The main motive for designing the Transformer was to enable parallel processing of the words in a sentence. This parallel processing is not possible in LSTMs, RNNs, or GRUs, because they take the words of the input sentence one by one. So in the encoder part of the Transformer, the very first layer has as many positions as there are words in the sentence, and each position converts its word into the corresponding embedding vector; the rest of the processing is then carried out on all positions in parallel. For more details, see the article: http://jalammar.github.io/illustrated-transformer/
     How to use this Transformer for text classification: since in text classification our output is a single label, not a sequence of numbers or vectors, we can remove the decoder part and use only the encoder. The output of the encoder is a set of vectors, one per word in the input sentence. We can then feed this set of output vectors into a CNN, or add an LSTM or RNN on top, and perform classification (see the sketch after this list).
  2. The input is the whole sentence or a batch of sentences, not word by word. You may have misunderstood that part.
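A minimal tf.keras sketch of this encoder-only setup (the layer sizes are illustrative, a single self-attention block stands in for the full encoder stack, and an average-pooling head stands in for the CNN/LSTM mentioned above):

```python
import tensorflow as tf

# Illustrative sizes, not taken from any tutorial.
vocab_size, seq_len, d_model, num_classes = 20000, 128, 64, 2

# The model takes the whole (tokenized) sentence at once.
tokens = tf.keras.Input(shape=(seq_len,), dtype="int32")
x = tf.keras.layers.Embedding(vocab_size, d_model)(tokens)

# One self-attention block standing in for the encoder stack;
# every token attends to every other token in parallel.
attn = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=d_model)(x, x)
x = tf.keras.layers.LayerNormalization()(x + attn)

# Collapse the per-token vectors into one sentence vector, then classify.
pooled = tf.keras.layers.GlobalAveragePooling1D()(x)
logits = tf.keras.layers.Dense(num_classes)(pooled)

model = tf.keras.Model(tokens, logits)
model.summary()
```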
answered Nov 03 '22 by Khobaib Alam