
How does the mask work in the Transformer if it receives a batch of different sentences as input?

I'm currently working on a PyTorch implementation of the Transformer model and had a question.

Right now, I've coded my model so that it receives source and target sentence pairs as batches. These sentences are encoded using their respective indices from a pre-made vocabulary. For example:

[[3,  2,  1, 23, 13, 50, 541, 0],
 [3, 24, 13,  0,  0,  0,   0, 0],
 [3, 98,  2,  4,  1,  23, 25, 4]]

where 0 are the padding indices.

My question is about how we should use the masking mechanism for these sentences if they're being fed in as batches. I suppose the reason I'm confused is that I know the mask looks something like:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

so that we can force our Decoder to only attend to the current and earlier positions, never the ones that come after. Do we apply this mask iteratively to the same sentence as we run the model? For example, if we were to use the first sentence I gave above:

# Iteration 1
[3, 0, 0, 0, 0, 0, 0, 0]

# Iteration 2
[3, 2, 0, 0, 0, 0, 0, 0]

.
.
.

and so we'd obtain a prediction at each position, for each sentence in each batch?

asked by Sean


2 Answers

Your decoder batch mask is a lower-triangular mask (which you have), element-wise AND'ed with a pad mask that is True wherever your caption isn't a pad value. Here's some toy code (mostly taken from https://github.com/SamLynnEvans/Transformer) for generating such a mask, with an example:

import numpy as np
import torch

def lower_triangular_mask(size):
    """
    Create a lower triangular mask
    """

    # np.triu(..., k=1) is 1 strictly above the diagonal, so == 0 leaves
    # True on and below the diagonal
    lt_mask = np.triu(np.ones((1, size, size)), k=1)
    lt_mask = torch.from_numpy(lt_mask) == 0

    return lt_mask

def create_mask(caption, pad_value):
    """
    Creates the transformer decode mask
    """

    # create pad mask: (batch, 1, seq_len), True where the token isn't padding
    pad_mask = (caption != pad_value).unsqueeze(-2)

    # create lower triangular mask: (1, seq_len, seq_len)
    size = caption.size(1)
    lt_mask = lower_triangular_mask(size)

    # the element-wise AND of the two masks broadcasts to (batch, seq_len, seq_len)
    return pad_mask & lt_mask

if __name__ == '__main__':
    torch.manual_seed(0)

    # Here, we generate some random sequences, with an assigned pad value of 1
    pad_value = 1
    caption = torch.randint(2, 10, size=(2, 5))
    caption[0, 3] = pad_value
    caption[0, 4] = pad_value

    print(caption)

    mask = create_mask(caption, pad_value)
    print(mask.size())
    print(mask)

The above code prints

tensor([[6, 9, 7, 1, 1],
        [5, 5, 9, 3, 5]])
torch.Size([2, 5, 5])
tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True, False, False]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

The first caption, which is shorter than the second, results in a mask that doesn't let the transformer attend past the end of its sequence.
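
For completeness, here is a minimal sketch of how a boolean mask like this is typically consumed inside scaled dot-product attention; the head-dimension handling and the -inf fill value are my assumptions, not something the code above does:

import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, d_k); mask: (batch, seq_len, seq_len), True = may attend
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if mask is not None:
        # broadcast the mask over the head dimension and block disallowed positions
        scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ v

# e.g. with the (2, 5, 5) mask from above and random q/k/v of shape (2, 8, 5, 64):
# out = attention(torch.rand(2, 8, 5, 64), torch.rand(2, 8, 5, 64), torch.rand(2, 8, 5, 64), mask)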

answered by daveboat


First of all, you have two separate types of masking: one in the encoder and one in the decoder.

Your encoder performs self-attention, so it needs to mask the pad tokens; this way real tokens don't attend to padding. If your input is of shape B x T x D, then your mask needs to be of shape B x T x T. That means each sequence in the batch gets its own mask.

Your decoder performs causal self-attention and so for a given token it needs to mask all of the tokens that come after it. This means that the mask will be the same for all of the sequences in the batch - a simple lower-triangular matrix of shape T x T.
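
As a rough sketch of those two shapes (the pad index and the token values below are made up for illustration):

import torch

pad_id = 0  # assumed padding index
tokens = torch.tensor([[3, 24, 13, 0, 0],
                       [3, 98,  2, 4, 1]])
B, T = tokens.shape

# encoder padding mask: True where attention is allowed, one (T, T) slice per sequence -> (B, T, T)
pad_mask = (tokens != pad_id).unsqueeze(1).expand(B, T, T)

# decoder causal mask: the same lower-triangular (T, T) matrix for every sequence in the batch
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

print(pad_mask.shape)     # torch.Size([2, 5, 5])
print(causal_mask.shape)  # torch.Size([5, 5])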

Here is an image showing padding masks and the causal mask.


Here is an extensive blog post that I wrote about the Transformer; you might like it.

answered by pi-tau