I'm currently studying the code of a Transformer, but I cannot understand the masked multi-head attention of the decoder. The paper says it is there to prevent you from seeing the word being generated, but I cannot understand: if the words after the word being generated have not been generated yet, how can they be seen?

I tried to read the code of a Transformer implementation (link: https://github.com/Kyubyong/transformer). The code that implements the mask is shown below. It uses a lower triangular matrix to mask, and I cannot understand why.
import tensorflow as tf

# `inputs` are the raw attention scores of shape (N, T_q, T_k); everything
# above the diagonal is replaced by a large negative number so that the
# softmax assigns it (almost) zero weight.
padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
In the masked multi-head attention layer, attention is applied only to tokens up to the current position (the index up to which the Transformer has already predicted), not to future tokens (which haven't been predicted yet). This is in stark contrast to the encoder, where attention is calculated over the entire sequence at once.
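To see concretely why the snippet in the question builds a lower triangular matrix, here is a minimal NumPy sketch (my own illustration, not the repository's code) for a toy 4-token sequence. Row i of the mask is 1 only for columns j <= i, so query position i can attend only to key positions up to i:

import numpy as np

T = 4                                   # toy sequence length
scores = np.random.randn(T, T)          # raw attention scores (T_q, T_k)

mask = np.tril(np.ones((T, T)))         # lower triangular: 1 where attention is allowed
neg_inf = -2 ** 32 + 1                  # same large negative constant as in the snippet
masked_scores = np.where(mask == 0, neg_inf, scores)

# softmax over the key axis: future positions get (practically) zero weight
weights = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 3))             # row i has non-zero entries only for columns <= i

Printing the weights shows that row 0 attends only to position 0, row 1 to positions 0 and 1, and so on: the future columns get essentially zero weight after the softmax, which is exactly the "no peeking at future tokens" rule described above.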
Multi-head attention is a module that runs several attention mechanisms in parallel. The independent attention outputs are then concatenated and linearly projected to the expected dimension.

The attention mechanism measures the similarity between the query q and each key k_i. This similarity gives a weight for each key, and the output is the weighted combination of all the corresponding values.

Masking is needed to prevent the attention mechanism of a Transformer from “cheating” in the decoder during training (on a translation task, for instance). This kind of “cheating-proof masking” is not present on the encoder side.
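To make the two previous paragraphs concrete, here is a minimal single-head sketch in NumPy (simplified, with names of my own choosing, not the actual multi-head module). The optional mask argument is exactly where the decoder's cheating-proof masking plugs in:

import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: (T_q, d_k), K: (T_k, d_k), V: (T_k, d_v), mask: (T_q, T_k) of 0/1
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # similarity of each query with each key
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)   # block the "illegal" connections
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ V                               # weighted combination of the values

# Decoder self-attention would pass mask = np.tril(np.ones((T, T))).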
I had the very same question after reading the Transformer paper. I found no complete and detailed answer to it on the Internet, so I'll try to explain my understanding of masked multi-head attention.

The short answer is: we need masking to make the training parallel. Parallelization is good because it allows the model to train faster.

Here's an example explaining the idea. Let's say we train to translate "I love you" into German. The encoder works in parallel mode: it can produce the vector representation of the input sequence ("I love you") within a constant number of steps (i.e. the number of steps doesn't depend on the length of the input sequence).
Let's say the encoder produces the numbers 11, 12, 13 as the vector representations of the input sequence. In reality these vectors will be much longer, but for simplicity we use short ones. Also for simplicity we ignore the service tokens, like the beginning-of-sequence marker, the end-of-sequence marker and others.

During the training we know that the translation should be "Ich liebe dich" (we always know the expected output during training). Let's say the expected vector representations of the "Ich liebe dich" words are 21, 22, 23.
If we make the decoder training in sequential mode, it'll look like the training of a Recurrent Neural Network. The following sequential steps will be performed:

Sequential operation #1. Input: 11, 12, 13. Trying to predict 21. The predicted output won't be exactly 21, let's say it'll be 21.1.

Sequential operation #2. Input: 11, 12, 13, and also 21.1 as the previous output. Trying to predict 22. The predicted output won't be exactly 22, let's say it'll be 22.3.

Sequential operation #3. Input: 11, 12, 13, and also 22.3 as the previous output. Trying to predict 23. The predicted output won't be exactly 23, let's say it'll be 23.5.

This means we'll need to make 3 sequential operations (in the general case, a sequential operation per input). Also we'll accumulate error on each next iteration. Also we don't use attention, as we only look at a single previous output.
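As a toy illustration of this sequential mode (a sketch with made-up numbers and a hypothetical decoder_step stand-in, not a real decoder), the point is only the dependency structure: step t cannot start before step t-1 has produced its output.

# Toy sketch of the sequential (RNN-like) mode: 3 strictly sequential steps.
encoder_outputs = [11, 12, 13]   # encoder representations of "I love you"
targets = [21, 22, 23]           # expected representations of "Ich liebe dich"

def decoder_step(enc, prev_output):
    # hypothetical stand-in for a real decoder step: it looks only at the
    # encoder outputs and the SINGLE previous output (no attention over the prefix)
    return sum(enc) / len(enc) + 0.5 * prev_output

prev = 0.0                       # nothing has been generated yet
predictions, losses = [], []
for target in targets:
    prev = decoder_step(encoder_outputs, prev)   # must wait for the previous step
    predictions.append(prev)
    losses.append((prev - target) ** 2)          # compare against the known target

print(predictions)               # one sequential operation per target word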
As we actually know the expected outputs, we can adjust the process and make it parallel. There's no need to wait for the previous step's output:

Parallel operation #A. Inputs: 11, 12, 13. Trying to predict 21.

Parallel operation #B. Inputs: 11, 12, 13, and also 21. Trying to predict 22.

Parallel operation #C. Inputs: 11, 12, 13, and also 21, 22. Trying to predict 23.

This algorithm can be executed in parallel and also it doesn't accumulate the error. And this algorithm uses attention (i.e. looks at all previous inputs), thus has more information about the context to consider while making the prediction.

And here is where we need the masking. The training algorithm knows the entire expected output (21, 22, 23). It hides (masks) a part of this known output sequence for each of the parallel operations.
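Here is a minimal NumPy sketch (an illustration with the toy numbers above, not the actual Transformer computation) of how the mask hides part of the known output for each parallel operation; each row of the mask corresponds to one of the operations #A, #B, #C:

import numpy as np

known_outputs = np.array([21.0, 22.0, 23.0])   # the full expected output sequence
T = len(known_outputs)

# Row i = the parallel operation for target position i.
# A 1 means "this known output is visible", a 0 means "hidden (masked)".
mask = np.tril(np.ones((T, T)), k=-1)          # strictly lower triangular: position i
                                               # may only see outputs before position i
print(mask)
# [[0. 0. 0.]   operation #A sees no previous output
#  [1. 0. 0.]   operation #B sees 21
#  [1. 1. 0.]]  operation #C sees 21 and 22

for i in range(T):
    visible = known_outputs[mask[i] == 1]
    print(f"predicting position {i}: visible previous outputs = {visible}")

In the real model the decoder input is the expected output shifted right (it starts with a beginning-of-sequence token), which is why the mask in the snippet from the question keeps the diagonal: query position i may see shifted-input positions up to and including i, which still corresponds only to outputs before position i.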
Masking itself is implemented as follows (quoting the original paper):
We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections
Note: during inference (not training) the decoder works in sequential (not parallel) mode, as it doesn't know the output sequence in advance. But this is different from the RNN approach: Transformer inference still uses self-attention and looks at all previous outputs, not only the most recent one.
Note 2: I've seen in some materials that masking can be used differently for non-translation applications. For example, for language modeling the masking can be used to hide some words from the input sentence and the model will try to predict them during the training using other, non-masked words (i.e. learn to understand the context).
The decoder is auto-regressive and can't see the future words.