
Why do we do batch matrix-matrix product?

I'm following the PyTorch seq2seq tutorial, and the torch.bmm method is used like this:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

I understand why we need to multiply attention weight and encoder outputs.

What I don't quite understand is why we need the bmm method here. The torch.bmm documentation says:

Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.

batch1 and batch2 must be 3-D tensors each containing the same number of matrices.

If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.
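For illustration, a quick shape check (with arbitrary sizes, chosen only to show the rule) behaves exactly as the documentation describes:

import torch

b, n, m, p = 2, 3, 4, 5
batch1 = torch.randn(b, n, m)   # b x n x m
batch2 = torch.randn(b, m, p)   # b x m x p

out = torch.bmm(batch1, batch2)
print(out.shape)                # torch.Size([2, 3, 5]), i.e. b x n x p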


asked Jun 12 '18 by aerin



2 Answers

In the seq2seq model, the encoder encodes the input sequences, which are given as mini-batches. Say, for example, the input is B x S x d, where B is the batch size, S is the maximum sequence length, and d is the word embedding dimension. Then the encoder's output is B x S x h, where h is the hidden state size of the encoder (which is an RNN).

Now, while decoding (during training), the input tokens are given one at a time, so the input is B x 1 x d and the decoder produces a tensor of shape B x 1 x h. To compute the context vector, we need to compare this decoder hidden state with all of the encoder's encoded states.
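As a concrete illustration of the shapes involved (the sizes below are arbitrary example values, not anything from the tutorial):

import torch

B, S, d, h = 4, 10, 300, 256            # batch, max sequence length, embedding dim, hidden size

encoder_outputs = torch.randn(B, S, h)  # all encoder hidden states: B x S x h
decoder_input = torch.randn(B, 1, d)    # current decoder input embedding: B x 1 x d
decoder_hidden = torch.randn(B, 1, h)   # decoder state for the current step: B x 1 x h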

So, say you have two tensors, T1 of shape B x S x h and T2 of shape B x 1 x h. You can do a batch matrix multiplication as follows.

out = torch.bmm(T1, T2.transpose(1, 2))

Essentially you are multiplying a tensor of shape B x S x h with a tensor of shape B x h x 1, and the result is of shape B x S x 1, which is the attention weight for each batch.

Here, the attention weights (B x S x 1) represent a similarity score between the decoder's current hidden state and all of the encoder's hidden states. Now you can multiply the attention weights with the encoder's hidden states (B x S x h) by transposing the latter first, which results in a tensor of shape B x h x 1. If you then squeeze at dim=2, you get a tensor of shape B x h, which is your context vector.

This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1) to predict the next token.
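Putting those steps together, here is a minimal, self-contained sketch of the batched attention computation described above (tensor names and sizes are illustrative only):

import torch

B, S, h = 4, 10, 256

T1 = torch.randn(B, S, h)   # encoder states: B x S x h
T2 = torch.randn(B, 1, h)   # decoder state:  B x 1 x h

# Attention weights: (B x S x h) bmm (B x h x 1) -> B x S x 1
attn_weights = torch.bmm(T1, T2.transpose(1, 2))
# (in a real model these scores would be normalized, e.g. with a softmax over dim=1)

# Context vector: (B x h x S) bmm (B x S x 1) -> B x h x 1, squeezed to B x h
context = torch.bmm(T1.transpose(1, 2), attn_weights).squeeze(2)

# Concatenate with the decoder state (squeezed to B x h) to predict the next token
combined = torch.cat((T2.squeeze(1), context), dim=1)   # B x 2h
print(attn_weights.shape, context.shape, combined.shape)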

answered Oct 04 '22 by Wasi Ahmad


While @wasiahmad is right about the general implementation of seq2seq, in the mentioned tutorial there is no batch (B=1), and the bmm is just over-engineering: it can be safely replaced with matmul with exactly the same model quality and performance. See for yourself; replace this:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

with this:

        attn_applied = torch.matmul(attn_weights,
                                    encoder_outputs)
        output = torch.cat((embedded[0], attn_applied), 1)

and run the notebook.
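A quick self-contained check of that equivalence for the B=1 case (the sizes below just stand in for the tutorial's max_length and hidden_size):

import torch

max_length, hidden_size = 10, 256
attn_weights = torch.softmax(torch.randn(1, max_length), dim=1)   # 1 x max_length
encoder_outputs = torch.randn(max_length, hidden_size)            # max_length x hidden_size

via_bmm = torch.bmm(attn_weights.unsqueeze(0),
                    encoder_outputs.unsqueeze(0))[0]               # 1 x hidden_size
via_matmul = torch.matmul(attn_weights, encoder_outputs)           # 1 x hidden_size

print(torch.allclose(via_bmm, via_matmul))   # True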


Also, note that while @wasiahmad describes the encoder input as B x S x d, in PyTorch 1.7.0 the GRU, which is the main engine of the encoder, expects an input of shape (seq_len, batch, input_size) by default. If you want to work with @wasiahmad's format, pass the batch_first=True flag.
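A small sketch of that difference (the GRU sizes are arbitrary example values):

import torch
import torch.nn as nn

S, B, d, h = 10, 4, 300, 256

gru = nn.GRU(input_size=d, hidden_size=h)   # default: expects (seq_len, batch, input_size)
out, hidden = gru(torch.randn(S, B, d))
print(out.shape)     # torch.Size([10, 4, 256]) -> S x B x h

gru_bf = nn.GRU(input_size=d, hidden_size=h, batch_first=True)   # expects (batch, seq_len, input_size)
out_bf, _ = gru_bf(torch.randn(B, S, d))
print(out_bf.shape)  # torch.Size([4, 10, 256]) -> B x S x h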

answered Oct 04 '22 by ihadanny