 

PyTorch - applying attention efficiently

I have built an RNN language model with attention, and I am creating a context vector for every element of the input by attending over all the previous hidden states (one direction only).

The most straightforward solution, in my opinion, is to use a for-loop over the RNN output so that each context vector is computed one after another.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNN_LM(nn.Module):
    def __init__(self, hidden_size, vocab_size, embedding_dim=None, droprate=0.5):
        super().__init__()
        if not embedding_dim:
            embedding_dim = hidden_size
        self.embedding_matrix = nn.Embedding(vocab_size, embedding_dim)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=False)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.vocab_dist = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x):
        x = self.dropout(self.embedding_matrix(x.view(-1, 1)))
        x, states = self.lstm(x)
        #print(x.size())
        x = x.squeeze()
        content_vectors = [x[0].view(1, -1)]
        # for-loop over hidden states and attention
        for i in range(1, x.size(0)):
            prev_states = x[:i]
            current_state = x[i].view(1, -1)

            attn_prod = torch.mm(self.attn(current_state), prev_states.t())
            attn_weights = F.softmax(attn_prod, dim=1)
            context = torch.mm(attn_weights, prev_states)
            content_vectors.append(context)

        return self.vocab_dist(self.dropout(torch.cat(content_vectors)))

Note: The forward method here is only used for training.
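
For illustration, here is a minimal training-step sketch (the optimizer choice, the hyper-parameters, and names like token_ids are just placeholders): the forward pass takes a 1-D tensor of token indices for a single sequence and returns one row of vocabulary logits per position, so a standard next-token cross-entropy loss can be applied to the shifted targets.

# assumes the RNN_LM class and the imports from the snippet above
model = RNN_LM(hidden_size=128, vocab_size=10000)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

token_ids = torch.randint(0, 10000, (50,))    # one sequence of 50 token indices
optimizer.zero_grad()
logits = model(token_ids)                     # shape: (50, vocab_size)
loss = criterion(logits[:-1], token_ids[1:])  # predict each next token from its context
loss.backward()
optimizer.step()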

However, this solution is not very efficient: computing each context vector sequentially means the code does not parallelize well. But since the context vectors do not depend on each other, I wonder whether there is a non-sequential way of calculating them.

So is there a way to compute the context vectors without a for-loop, so that more of the computation can be parallelized?

asked Dec 10 '18 by MBT

1 Answer

Ok, for clarity: I assume we only really care about vectorizing the for loop. What is the shape of x? Assuming x is 2-dimensional, I have the following code, where v1 executes your loop and v2 is a vectorized version:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(3, 6)

def v1():
    for i in range(1, x.size(0)):
        prev = x[:i]
        curr = x[i].view(1, -1)

        prod = torch.mm(curr, prev.t())
        attn = prod # same shape
        context = torch.mm(attn, prev)
        print(context)

def v2():
    # we're going to unroll the loop by vectorizing over the new,
    # 0-th dimension of `x`. We repeat it as many times as there
    # are iterations in the for loop
    repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)

    # we're looking to build a `prevs` tensor such that
    # prevs[i, x, y] == prev[x, y] at i-th iteration of the loop in v1,
    # up to 0-padding necessary to make them all the same size.
    # We need to build a higher-dimensional equivalent of torch.triu
    xs = torch.arange(x.size(0)).reshape(1, -1, 1)
    zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)

    # this is an equivalent of the above iteration starting at 1
    prevs = prevs[:-1]
    currs = x[1:]

    # a batched matrix multiplication
    prod = torch.matmul(currs, prevs.transpose(1, 2))
    attn = prod # same shape
    context = torch.matmul(attn, prevs)
    # equivalent of a higher dimensional torch.diagonal
    contexts = torch.einsum('iij->ij', (context))
    print(contexts)

print(x)

print('\n------ v1 -------\n')
v1()
print('\n------ v2 -------\n')
v2()

which vectorizes your loop, with some caveats. First, I assume x is 2-dimensional. Secondly, I skip taking the softmax, claiming that it doesn't change the size of the input and thus doesn't affect vectorization. That's true, but unfortunately the softmax of a 0-padded vector v is not equal to a 0-padded softmax of the unpadded v. This can be fixed with renormalization, though. Please let me know if my assumptions are correct and whether this is a good enough starting point for your work.
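
As an illustration of that caveat, here is a sketch that restores the skipped softmax by masking the padded positions with -inf before applying it, so they get zero attention weight (an alternative to the renormalization mentioned above); like v2, it still leaves out the self.attn linear layer:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 6)  # same toy input as above

def v2_with_softmax():
    # build `prevs` and `currs` exactly as in v2
    repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)
    xs = torch.arange(x.size(0)).reshape(1, -1, 1)
    zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)[:-1]
    currs = x[1:]

    prod = torch.matmul(currs, prevs.transpose(1, 2))  # (N-1, N-1, N)

    # at batch index i only the previous positions k <= i hold real data;
    # fill the padded positions with -inf so softmax assigns them zero weight
    n = x.size(0)
    mask = torch.arange(n).reshape(1, 1, -1) > torch.arange(n - 1).reshape(-1, 1, 1)
    attn = F.softmax(prod.masked_fill(mask, float('-inf')), dim=2)

    context = torch.matmul(attn, prevs)
    contexts = torch.einsum('iij->ij', context)
    print(contexts)

v2_with_softmax()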

answered Nov 14 '22 by Jatentaki