How to understand masked multi-head attention in transformer

I'm currently studying the code of a Transformer, but I cannot understand the masked multi-head attention in the decoder. The paper says it prevents the model from seeing the word being generated, but I cannot understand this: if the words after the one being generated have not been generated yet, how can they be seen?

I tried reading a Transformer implementation (link: https://github.com/Kyubyong/transformer). The code that implements the mask is shown below. It uses a lower triangular matrix as the mask, and I cannot understand why.

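import tensorflow as tf

# inputs: attention scores of shape (N, T_q, T_k), before the softmax
padding_num = -2 ** 32 + 1  # large negative number standing in for -infinity
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k), zeros above the diagonal
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)  # replace scores above the diagonal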
Neptuner asked Sep 27 '19 02:09


2 Answers

I had the very same question after reading the Transformer paper. I found no complete and detailed answer on the Internet, so I'll try to explain my understanding of Masked Multi-Head Attention.

The short answer is - we need masking to make the training parallel. Parallelization is good because it allows the model to train faster.

Here's an example explaining the idea. Let's say we train a model to translate "I love you" to German. The encoder works in parallel mode - it can produce a vector representation of the input sequence ("I love you") within a constant number of steps (i.e. the number of steps doesn't depend on the length of the input sequence).

Let's say the encoder produces the numbers 11, 12, 13 as the vector representations of the input sequence. In reality these vectors will be much longer, but for simplicity we use short ones. Also for simplicity we ignore the service tokens, such as the markers for the beginning of the sequence, the end of the sequence and others.

During training we know that the translation should be "Ich liebe dich" (we always know the expected output during training). Let's say the expected vector representations of the words "Ich liebe dich" are 21, 22, 23.

If we train the decoder in sequential mode, it'll look like training a Recurrent Neural Network. The following sequential steps will be performed:

  • Sequential operation #1. Input: 11, 12, 13.
    • Trying to predict 21.
    • The predicted output won't be exactly 21, let's say it'll be 21.1.
  • Sequential operation #2. Input: 11, 12, 13, and also 21.1 as the previous output.
    • Trying to predict 22.
    • The predicted output won't be exactly 22, let's say it'll be 22.3.
  • Sequential operation #3. Input: 11, 12, 13, and also 22.3 as the previous output.
    • Trying to predict 23.
    • The predicted output won't be exactly 23, let's say it'll be 23.5.

This means we'll need to make 3 sequential operations (in the general case, one sequential operation per output word). Also, the error accumulates with each iteration. And we don't use attention, as we only look at a single previous output.

Since we actually know the expected outputs, we can adjust the process and make it parallel. There's no need to wait for the output of the previous step.

  • Parallel operation #A. Inputs: 11, 12, 13.
    • Trying to predict 21.
  • Parallel operation #B. Inputs: 11, 12, 13, and also 21.
    • Trying to predict 22.
  • Parallel operation #C. Inputs: 11, 12, 13, and also 21, 22.
    • Trying to predict 23.

This algorithm can be executed in parallel, and it doesn't accumulate the error. It also uses attention (i.e. looks at all previous inputs), so it has more context to consider when making each prediction.
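To make the parallel operations concrete, here is a minimal sketch in plain Python. The function name `teacher_forced_examples` is my own (hypothetical), and the numbers are the toy values from this answer, not real vectors. The point is that every example is built from the known expected outputs, never from a previous prediction, so none of them depends on another.

```python
def teacher_forced_examples(encoder_output, targets):
    """Build one independent training example per target position.

    Each example holds the encoder output, the known (ground-truth)
    targets that come before the position, and the target to predict.
    """
    return [
        (encoder_output, targets[:i], targets[i])
        for i in range(len(targets))
    ]


encoder_output = [11, 12, 13]  # toy stand-ins for the encoder vectors
targets = [21, 22, 23]         # toy stand-ins for "Ich liebe dich"

examples = teacher_forced_examples(encoder_output, targets)
for name, (enc, known, goal) in zip("ABC", examples):
    print(f"#{name}: inputs {enc} and also {known}, trying to predict {goal}")
```

Because the examples share no intermediate results, they can be processed in one batch instead of one after another.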

And here is where we need the masking. The training algorithm knows the entire expected output (21, 22, 23). It hides (masks) a part of this known output sequence for each of the parallel operations.

  • When it executes #A - it hides (masks) the entire output.
  • When it executes #B - it hides the 2nd and 3rd outputs.
  • When it executes #C - it hides the 3rd output.
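This is also where the lower triangular matrix from the question's code comes from. Below is a rough sketch, assuming the decoder input is the expected output shifted right by one position with a start token in front (I ignored service tokens above, so `START` and the helper `causal_mask` are my own illustrative names). Row i of the mask marks which decoder inputs operation i may look at.

```python
def causal_mask(length):
    """Lower triangular mask: row i may attend to columns 0..i.

    1 means visible, 0 means hidden (masked).
    """
    return [
        [1 if col <= row else 0 for col in range(length)]
        for row in range(length)
    ]


START = "<s>"  # hypothetical start-of-sequence token (an assumption)
targets = [21, 22, 23]
decoder_inputs = [START] + targets[:-1]  # expected output shifted right

mask = causal_mask(len(decoder_inputs))
for name, row, goal in zip("ABC", mask, targets):
    visible = [tok for tok, keep in zip(decoder_inputs, row) if keep]
    print(f"#{name}: mask row {row}, sees {visible}, predicts {goal}")
```

With the shift, the diagonal stays visible: operation #A sees only the start token, which matches "hides the entire output" above, since no real output word is visible to it.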

Masking itself is implemented as follows (from the original paper):

"We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections."
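Here is a small sketch of what "setting to −∞ before the softmax" does to one row of attention scores. It is plain Python rather than the TensorFlow from the question, the scores are made up, and `masked_softmax` is my own helper name. The question's code uses a large negative number (-2 ** 32 + 1) instead of a true −∞, which has practically the same effect.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over one row of scores with hidden positions removed.

    Positions where mask is 0 are set to -inf first, so they get
    exactly zero weight. Assumes at least one position is visible.
    """
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]  # exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]  # made-up attention scores for one query
mask = [1, 1, 0]          # the third position is a "future" word

weights = masked_softmax(scores, mask)
print(weights)  # the last weight is exactly 0.0
```

The hidden position had the highest raw score, yet it receives no attention at all, and the remaining weights still sum to 1.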

Note: during inference (not training) the decoder works in sequential (not parallel) mode, as it doesn't know the output sequence in advance. But this differs from the RNN approach, as Transformer inference still uses self-attention and looks at all previous outputs (not only the most recent one).
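The sequential inference loop can be sketched like this. It is only an illustration: `greedy_decode`, `toy_predict_next` and the `"</s>"` end token are hypothetical names of mine, and the toy function merely replays a fixed script where a trained decoder would normally run.

```python
def greedy_decode(predict_next, encoder_output, max_len, end_token):
    """Generate outputs one at a time, feeding back everything so far."""
    generated = []
    for _ in range(max_len):
        # The model receives ALL previous outputs, not only the last one.
        token = predict_next(encoder_output, generated)
        if token == end_token:
            break
        generated.append(token)
    return generated


def toy_predict_next(encoder_output, generated):
    """Toy stand-in for a trained decoder: replays a fixed script."""
    script = [21, 22, 23, "</s>"]
    return script[len(generated)]


print(greedy_decode(toy_predict_next, [11, 12, 13], max_len=10, end_token="</s>"))
```

Each pass through the loop has to wait for the previous one, which is exactly the dependency that masking removes during training.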

Note 2: I've seen in some materials that masking can be used differently for non-translation applications. For example, in language modeling, masking can hide some words of the input sentence, and the model tries to predict them during training using the other, non-masked words (i.e. it learns to understand the context).

artoby answered Sep 24 '22 18:09


The decoder is autoregressive and can't see future words:

  1. the decoder in the Transformer is autoregressive;
  2. which means it predicts the next token from the previous ones;
  3. so each position of the input x can't see the words that come after it;
  4. we use masked multi-head attention to enforce this.
zhangjq answered Sep 20 '22 18:09