What is the difference between attn_mask and key_padding_mask in MultiheadAttention?

What is the difference between attn_mask and key_padding_mask in PyTorch's MultiheadAttention? The documentation says:

key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored.

attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.

Thanks in advance.

asked Jun 29 '20 by one



2 Answers

The key_padding_mask is used to mask out positions that are padding, i.e., positions after the end of the input sequence. It is always specific to the input batch and depends on how long the sequences in the batch are compared to the longest one. It is a 2D tensor of shape batch size × input length.
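
For illustration, here is a minimal sketch (the lengths and shapes below are made up, not from the question) of how such a padding mask is typically built from per-sequence lengths, with True marking a padded key position that attention should ignore:

import torch

lengths = torch.tensor([5, 3, 7])            # true lengths of the three sequences
src_len = int(lengths.max())                 # padded length of the batch
# (batch size × input length) boolean mask, True where the position is padding
key_padding_mask = torch.arange(src_len)[None, :] >= lengths[:, None]
print(key_padding_mask.shape)                # torch.Size([3, 7])
print(key_padding_mask[1])                   # tensor([False, False, False,  True,  True,  True,  True])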

On the other hand, attn_mask says which key-value pairs are valid. In a Transformer decoder, a triangular mask is used to simulate inference time and prevent attending to "future" positions. This is what attn_mask is usually used for. If it is a 2D tensor, its shape is input length × input length. You can also have a mask that is specific to every item in the batch; in that case, you use a 3D tensor of shape (batch size × num heads) × input length × input length. (So, in theory, you can simulate key_padding_mask with a 3D attn_mask.)
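
As a rough sketch of how the two masks are passed together (the dimensions and module settings below are my own assumptions for illustration, not taken from the question):

import torch
import torch.nn as nn

embed_dim, num_heads, batch, seq_len = 16, 4, 3, 7
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)

# 2D causal attn_mask of shape (seq_len × seq_len): True above the diagonal
# means query i may not attend to key j for j > i; it is broadcast over the batch.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# 2D key_padding_mask of shape (batch × seq_len): True marks padded key positions.
lengths = torch.tensor([5, 3, 7])
key_padding_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]

out, weights = mha(x, x, x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
print(out.shape)        # torch.Size([3, 7, 16])
print(weights.shape)    # torch.Size([3, 7, 7]), attention weights averaged over heads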

answered Oct 20 '22 by Jindřich


I think they work the same way: both masks define which query-key attention scores will not be used. The only difference between the two is the shape in which you find it more convenient to supply the mask.

According to the code, it seems the two masks are merged (their union is taken), so they play the same role -- deciding which query-key attention scores will not be used. Because the union is taken, the two masks can carry different values if you actually need both, or you can supply your mask through whichever argument has the more convenient required shape. Here is part of the original code from pytorch/functional.py, around line 5227, in the function multi_head_attention_forward():

...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
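
To check this union behaviour, here is a small sketch of my own (shapes and values are made up): passing the padding information through key_padding_mask, or through an equivalent 3D attn_mask built with the same view/expand/reshape trick, gives the same output.

import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, bsz, src_len = 16, 4, 2, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)  # dropout defaults to 0.0
x = torch.randn(bsz, src_len, embed_dim)

lengths = torch.tensor([5, 3])
key_padding_mask = torch.arange(src_len)[None, :] >= lengths[:, None]    # (bsz, src_len)

# the same information expressed as a 3D attn_mask of shape (bsz * num_heads, tgt_len, src_len)
attn_mask_3d = key_padding_mask.view(bsz, 1, 1, src_len) \
    .expand(-1, num_heads, src_len, -1) \
    .reshape(bsz * num_heads, src_len, src_len)

out1, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
out2, _ = mha(x, x, x, attn_mask=attn_mask_3d)
print(torch.allclose(out1, out2, atol=1e-6))   # True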

Please correct me if you have a different opinion or if I am wrong.

answered Oct 20 '22 by Flora Sun