I am having a difficult time understanding transformers. Everything is getting clearer bit by bit, but one thing that makes me scratch my head is: what is the difference between src_mask and src_key_padding_mask, which are passed as arguments to the forward function in both the encoder layer and the decoder layer?
https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer
The key thing to notice is the difference between the use of the _mask tensors vs the _key_padding_mask tensors.
Inside the transformer, when attention is computed we get an intermediate score tensor with all the pairwise comparisons: of size [Tx, Tx] (for the input to the encoder), [Ty, Ty] (for the shifted output, one of the inputs to the decoder) and [Ty, Tx] (for the memory attention, i.e. the attention between the output of the encoder/memory and the input to the decoder/shifted output).
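To make those shapes concrete, here is a minimal sketch (queries and keys collapsed into a single tensor per side, ignoring batching, heads, projections and the 1/sqrt(d) scaling; all sizes are made up):
import torch

Tx, Ty, D = 10, 7, 512            # source length, target length, feature size (made up)
src = torch.rand(Tx, D)           # one source example, ignoring batch and heads for clarity
tgt = torch.rand(Ty, D)           # one (right-shifted) target example

enc_scores = src @ src.T          # encoder self-attention scores:      [Tx, Tx]
dec_scores = tgt @ tgt.T          # decoder self-attention scores:      [Ty, Ty]
mem_scores = tgt @ src.T          # decoder-to-memory attention scores: [Ty, Tx]
print(enc_scores.shape, dec_scores.shape, mem_scores.shape)
# torch.Size([10, 10]) torch.Size([7, 7]) torch.Size([7, 10])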
So these are the uses for each of the masks in the transformer (note the notation from the PyTorch docs is as follows: Tx = S is the source sequence length (e.g. the max length over the input batch), Ty = T is the target sequence length (e.g. the max length over the target batch), B = N is the batch size, and D = E is the feature dimension):
src_mask [Tx, Tx] = [S, S] – the additive mask for the src sequence (optional). This is applied when doing atten_src + src_mask. I'm not sure of an example input (see tgt_mask for one), but the typical use is to add -inf, so one could mask the src attention that way if desired (see the sketch below).
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
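A minimal sketch of the additive (float) variant, with a made-up rule that blocks attention to the last source position:
import torch
import torch.nn.functional as F

Tx = 4
scores = torch.rand(Tx, Tx)                   # hypothetical raw src attention scores
src_mask = torch.zeros(Tx, Tx)                # additive float mask, 0.0 = leave unchanged
src_mask[:, -1] = float('-inf')               # made-up rule: nobody may attend to the last src position

attn = F.softmax(scores + src_mask, dim=-1)   # atten_src + src_mask, then softmax
print(attn[:, -1])                            # tensor([0., 0., 0., 0.]) – the masked column gets zero weight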
tgt_mask [Ty, Ty] = [T, T] – the additive mask for the tgt sequence (optional). This is applied when doing atten_tgt + tgt_mask. An example use is the triangular (causal) mask that prevents the decoder from cheating by looking at future positions. Since the tgt is right-shifted, the first token is the start-of-sequence embedding SOS/BOS, so the first row of the mask has a zero in the first entry and -inf in the remaining entries. See the concrete example in the appendix (and the helper sketch below).
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
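PyTorch also ships a helper that builds exactly this additive causal mask; a small sketch (the tiny Transformer instance here exists only to reach the helper method):
import torch.nn as nn

helper = nn.Transformer(d_model=8, nhead=2)            # tiny model, only used to call the helper
tgt_mask = helper.generate_square_subsequent_mask(4)
print(tgt_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])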
memory_mask [Ty, Tx] = [T, S] – the additive mask for the encoder output (optional). This is applied when doing atten_memory + memory_mask. I'm not sure of a typical example use, but as before, adding -inf sets some of the attention weights to zero (see the sketch below).
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
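Just to show the mechanics, here is a made-up boolean [Ty, Tx] memory_mask that hides the first encoder position from every decoder position:
import torch

Ty, Tx = 3, 5
memory_mask = torch.zeros(Ty, Tx, dtype=torch.bool)
memory_mask[:, 0] = True            # True = the decoder may NOT attend to this memory position
print(memory_mask)
# tensor([[ True, False, False, False, False],
#         [ True, False, False, False, False],
#         [ True, False, False, False, False]])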
src_key_padding_mask [B, Tx] = [N, S] – the ByteTensor mask for src keys per batch (optional). Since your src batch usually contains sequences of different lengths, it's common to mask out the padding vectors you appended at the end. For this you specify, for each example in your batch, which positions are padding (typically derived from the sequence lengths; see the sketch below and the concrete example in the appendix).
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
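A minimal sketch (with made-up lengths) of turning per-example lengths into that [B, Tx] boolean mask, where True marks the padded positions that should be ignored:
import torch

lengths = torch.tensor([3, 5, 2])      # true length of each of the B = 3 sequences
Tx = 5                                 # padded (max) source length
src_key_padding_mask = torch.arange(Tx).unsqueeze(0) >= lengths.unsqueeze(1)   # [B, Tx]
print(src_key_padding_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False],
#         [False, False,  True,  True,  True]])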
tgt_key_padding_mask [B, Ty] = [N, T] – the ByteTensor mask for tgt keys per batch (optional). Same as the previous one. See the concrete example in the appendix.
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
memory_key_padding_mask [B, Tx] = [N, S]
– the ByteTensor mask for memory keys per batch (optional).
Same as previous.
See concrete example in appendix.
If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
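Putting all six masks together, a minimal sketch of a full nn.Transformer forward pass (all dimensions and masks are made up; here every mask is boolean, with True meaning "not allowed to attend / padding"):
import torch
import torch.nn as nn

S, T, N, E = 10, 7, 4, 512
model = nn.Transformer(d_model=E)                     # nhead defaults to 8

src = torch.rand(S, N, E)                             # encoder input
tgt = torch.rand(T, N, E)                             # decoder input (right-shifted target)

src_mask = torch.zeros(S, S, dtype=torch.bool)                          # [S, S], nothing blocked
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)   # [T, T], causal
memory_mask = torch.zeros(T, S, dtype=torch.bool)                       # [T, S], nothing blocked
src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)              # [N, S], no padding here
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)              # [N, T], no padding here
memory_key_padding_mask = src_key_padding_mask                          # [N, S], same as src padding

out = model(src, tgt,
            src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
print(out.shape)   # torch.Size([7, 4, 512])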
Examples from pytorch tutorial (https://pytorch.org/tutorials/beginner/translation_transformer.html):
src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)
returns a tensor of booleans of size [Tx, Tx]:
tensor([[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False]])
mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)  # upper-triangular booleans (True on and above the diagonal)
mask = mask.transpose(0, 1).float()                            # transpose -> lower-triangular 1.0 / 0.0
mask = mask.masked_fill(mask == 0, float('-inf'))              # disallowed (future) positions become -inf
mask = mask.masked_fill(mask == 1, float(0.0))                 # allowed positions become 0.0
generates the triangular (causal) mask for the right-shifted output, which is the input to the decoder:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, -inf],
...,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.]])
Usually the right-shifted output has the BOS/SOS at the beginning, and the tutorial gets the right shift simply by prepending that BOS/SOS token and then trimming the last element with tgt_input = tgt[:-1, :] (see the sketch below).
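In other words, a tiny sketch with made-up token ids (BOS_IDX and EOS_IDX are assumed special tokens):
import torch

BOS_IDX, EOS_IDX = 2, 3                                         # assumed special token ids
tgt = torch.tensor([[BOS_IDX], [11], [12], [13], [EOS_IDX]])    # [T, N] with N = 1, BOS already prepended

tgt_input = tgt[:-1, :]    # decoder input:  BOS, 11, 12, 13
tgt_out   = tgt[1:, :]     # loss target:    11, 12, 13, EOS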
The padding masks just mask the padding at the end. The src padding is usually the same as the memory padding. The tgt has its own sequences and thus its own padding. Example:
src_padding_mask = (src == PAD_IDX).transpose(0, 1)   # [S, N] -> [N, S], True where src is padding
tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)   # [T, N] -> [N, T], True where tgt is padding
memory_padding_mask = src_padding_mask                # memory comes from src, so it shares the same padding
Output:
tensor([[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., True, True, True]])
Note that a False means there is no padding token at that position (so, yes, use that value in the transformer forward pass) and a True means there is a padding token (so mask it out so the transformer forward pass is not affected by it).
The answers are sort of spread around, but I found only these 3 references useful (the docs for the separate layers weren't very useful, honestly):
I must say PyTorch implementations are a bit confusing as they contain too many mask parameters. But I can shed light on the two mask parameters that you are referring to. Both src_mask and src_key_padding_mask are used in the MultiheadAttention mechanism. According to the documentation of MultiheadAttention:
key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention.
attn_mask – 2D or 3D mask that prevents attention to certain positions.
As you know from the paper Attention Is All You Need, MultiheadAttention is used in both the Encoder and the Decoder. However, in the Decoder there are two types of MultiheadAttention: one is called Masked MultiheadAttention and the other is the regular MultiheadAttention. To accommodate both of these techniques, PyTorch uses the two above-mentioned parameters in its MultiheadAttention implementation.
So, long story short:
attn_mask and key_padding_mask are used in the Encoder's MultiheadAttention and in the Decoder's Masked MultiheadAttention.
memory_mask is used in the Decoder's MultiheadAttention mechanism, as pointed out here.
Looking into the implementation of MultiheadAttention might help you.
As you can see from here and here, first src_mask
is used to block specific positions from attending and then key_padding_mask
is used to block attending to pad tokens.
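A minimal sketch of both masks going into a single MultiheadAttention call (all sizes made up); the returned attention weights are zero wherever either mask blocked a position:
import torch
import torch.nn as nn

S, N, E = 4, 1, 8
mha = nn.MultiheadAttention(embed_dim=E, num_heads=2)
x = torch.rand(S, N, E)                                                   # [S, N, E], batch_first=False

attn_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)    # block attending to future positions
key_padding_mask = torch.tensor([[False, False, False, True]])            # last token of the (single) example is padding

out, weights = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
print(weights)   # [N, S, S]; zeros where either mask blocked attention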
Note. Answer updated based on @michael-jungo's comment.
To give a small example, consider I want to build a sequential recommender, i.e., given the items a user has purchased up to time 't', predict the next item at 't+1':
u1 - [i1, i2, i7]
u2 - [i2, i5]
u3 - [i6, i7, i1, i2]
For this task, I could use a transformer, where I would make the sequences equal length by padding them with 0's on the left.
u1 - [0, i1, i2, i7]
u2 - [0, 0, i2, i5]
u3 - [i6, i7, i1, i2]
I will use key_padding_mask to tell PyTorch that the 0's should be ignored.
Now, consider user u3, where given [i6] I want to predict [i7], and later given [i6, i7] I want to predict [i1], i.e., I want causal attention, so that the attention doesn't peek at future elements. For this, I will use attn_mask. In PyTorch's convention, a True entry in a boolean attn_mask marks a position that is not allowed to be attended to, so for user u3 the causal attn_mask will be like
[[False, True , True , True ],
 [False, False, True , True ],
 [False, False, False, True ],
 [False, False, False, False]]
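A sketch of how this could look in code for the toy example above (item ids, embedding size and the encoder configuration are all made up):
import torch
import torch.nn as nn

PAD = 0
# left-padded item-id sequences for u1, u2, u3
batch = torch.tensor([[PAD, 1, 2, 7],
                      [PAD, PAD, 2, 5],
                      [6, 7, 1, 2]])                                       # [B, L]

key_padding_mask = (batch == PAD)                                          # [B, L], True where PAD should be ignored
attn_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)     # [L, L] causal mask: True = blocked

emb = nn.Embedding(10, 16, padding_idx=PAD)
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=16, nhead=2), num_layers=1)

x = emb(batch).transpose(0, 1)                                             # [L, B, E] (batch_first=False)
out = encoder(x, mask=attn_mask, src_key_padding_mask=key_padding_mask)
print(out.shape)                                                           # torch.Size([4, 3, 16])
# note: queries that are themselves padding end up with every key blocked,
# so their outputs come out as NaN; those positions are ignored downstream anyway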