I'm currently working on a PyTorch implementation of the Transformer model and had a question.
Right now, I've coded my model so that it receives source and target sentence pairs as batches. These sentences are encoded using their respective indices from a pre-made vocabulary. For example:
[[3, 2, 1, 23, 13, 50, 541, 0],
[3, 24, 13, 0, 0, 0, 0, 0],
[3, 98, 2, 4, 1, 23, 25, 4]]
where 0 is the padding index.
My question is regarding how we should use the masking mechanism for these sentences if they're being fed in as batches. I suppose the reason I'm confused is that I know the mask looks something like:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
so that we can force our Decoder to attend only to the current and earlier positions, never to future ones. Do we apply this mask iteratively to the same sentence as we run the model? For example, if we were to use the first sentence I gave above:
# Iteration 1
[3, 0, 0, 0, 0, 0, 0, 0]
# Iteration 2
[3, 2, 0, 0, 0, 0, 0, 0]
.
.
.
and so we'd obtain a prediction at each position, for each sentence in each batch?
Your decoder batch mask is a lower-triangular mask (which you have), element-wise AND'ed with a pad mask that is True wherever your caption isn't a pad value. Here's some toy code (mostly taken from https://github.com/SamLynnEvans/Transformer) for generating such a mask, with an example:
import numpy as np
import torch


def lower_triangular_mask(size):
    """
    Create a lower triangular (causal) mask.
    """
    lt_mask = np.triu(np.ones((1, size, size)), k=1)
    lt_mask = torch.from_numpy(lt_mask) == 0
    return lt_mask


def create_mask(caption, pad_value):
    """
    Creates the transformer decoder mask.
    """
    # create pad mask: True wherever the token is not a pad value
    pad_mask = (caption != pad_value).unsqueeze(-2)
    # create lower triangular mask
    size = caption.size(1)
    lt_mask = lower_triangular_mask(size)
    # return the element-wise AND of the two masks
    return pad_mask & lt_mask


if __name__ == '__main__':
    torch.manual_seed(0)
    # Here, we generate some random sequences, with an assigned pad value of 1
    pad_value = 1
    caption = torch.randint(2, 10, size=(2, 5))
    caption[0, 3] = pad_value
    caption[0, 4] = pad_value
    print(caption)

    mask = create_mask(caption, pad_value)
    print(mask.size())
    print(mask)
Running the above code prints:
tensor([[6, 9, 7, 1, 1],
        [5, 5, 9, 3, 5]])
torch.Size([2, 5, 5])
tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])
The first caption, which is shorter than the second, results in a mask which doesn't let the transformer attend past the end of its sequence.
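For completeness, here's a rough sketch (my own illustration, not from the repo above) of how a boolean mask like this is typically consumed inside scaled dot-product attention: positions where the mask is False are filled with -inf before the softmax, so they receive zero attention weight. The function name and tensor shapes here are just assumptions for the example.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    q, k, v: (batch, seq_len, d_model)
    mask: boolean, broadcastable to (batch, seq_len, seq_len); True = may attend
    """
    scores = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)
    if mask is not None:
        # masked-out positions get -inf so softmax assigns them zero weight
        scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v

# e.g. with the mask from create_mask above:
# out = scaled_dot_product_attention(q, k, v, mask=mask)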
First of all, you have two separate types of masking: one in the encoder and one in the decoder.
Your encoder performs self-attention, so it needs to mask pad tokens so that real tokens don't attend to them. If your input is of shape B x T x D, then your mask needs to be of shape B x T x T; in other words, each sequence in the batch gets its own mask.
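For instance (a minimal sketch of my own, reusing the padded batch from the question with 0 as the pad index), such a B x T x T padding mask can be built by broadcasting a B x 1 x T key-padding mask across the query dimension:

import torch

pad_idx = 0
src = torch.tensor([[3, 2, 1, 23, 13, 50, 541, 0],
                    [3, 24, 13, 0, 0, 0, 0, 0],
                    [3, 98, 2, 4, 1, 23, 25, 4]])

# True where the key position holds a real token (B x 1 x T),
# then expanded across the query dimension to B x T x T
src_pad_mask = (src != pad_idx).unsqueeze(1).expand(-1, src.size(1), -1)
print(src_pad_mask.shape)  # torch.Size([3, 8, 8])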
Your decoder performs causal self-attention, so for a given token it needs to mask all of the tokens that come after it. This means the mask is the same for every sequence in the batch: a simple lower-triangular matrix of shape T x T.
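A quick sketch of that causal mask (torch.tril is one common way to build it; this is just an illustration, matching the 4 x 4 example in the question):

import torch

T = 4
# True on and below the diagonal: position i may only attend to positions 0..i
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
print(causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])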
[Image: padding masks for the encoder and the causal mask for the decoder]
Here is an extensive blog post that I wrote about the Transformer; you might like it.