I have built an RNN language model with attention, and I create a context vector for every element of the input by attending over all the previous hidden states (one direction only).
The most straightforward solution, in my opinion, is to use a for-loop over the RNN output, so that each context vector is computed one after another.
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNN_LM(nn.Module):
    def __init__(self, hidden_size, vocab_size, embedding_dim=None, droprate=0.5):
        super().__init__()
        if not embedding_dim:
            embedding_dim = hidden_size
        self.embedding_matrix = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=False)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.vocab_dist = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x):
        # x: (seq_len,) tensor of token indices
        x = self.dropout(self.embedding_matrix(x.view(-1, 1)))  # (seq_len, 1, embedding_dim)
        x, states = self.lstm(x)                                # (seq_len, 1, hidden_size)
        x = x.squeeze()                                         # (seq_len, hidden_size)
        content_vectors = [x[0].view(1, -1)]
        # for-loop over hidden states and attention
        for i in range(1, x.size(0)):
            prev_states = x[:i]
            current_state = x[i].view(1, -1)
            attn_prod = torch.mm(self.attn(current_state), prev_states.t())
            attn_weights = F.softmax(attn_prod, dim=1)
            context = torch.mm(attn_weights, prev_states)
            content_vectors.append(context)
        return self.vocab_dist(self.dropout(torch.cat(content_vectors)))
Note: the forward method here is only used for training.
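For reference, a minimal training-style call would look something like the sketch below; the hyperparameters and the random input sequence are made up purely for illustration:

import torch
import torch.nn as nn

# made-up hyperparameters, just to illustrate the expected shapes
vocab_size, hidden_size, seq_len = 1000, 64, 12
model = RNN_LM(hidden_size=hidden_size, vocab_size=vocab_size)

tokens = torch.randint(0, vocab_size, (seq_len,))        # one token sequence
logits = model(tokens[:-1])                              # (seq_len - 1, vocab_size)
loss = nn.functional.cross_entropy(logits, tokens[1:])   # next-token prediction
loss.backward()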
However, this solution is not very efficient: computing each context vector sequentially means the code does not parallelize well. Since the context vectors do not depend on each other, I wonder whether there is a non-sequential way of calculating them.
So, is there a way to compute the context vectors without a for-loop, so that more of the computation can be parallelized?
OK, for clarity: I assume we only really care about vectorizing the for-loop. What is the shape of x? Assuming x is 2-dimensional, I have the following code, where v1 executes your loop and v2 is a vectorized version:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 6)


def v1():
    for i in range(1, x.size(0)):
        prev = x[:i]
        curr = x[i].view(1, -1)
        prod = torch.mm(curr, prev.t())
        attn = prod  # same shape
        context = torch.mm(attn, prev)
        print(context)


def v2():
    # we're going to unroll the loop by vectorizing over the new,
    # 0-th dimension of `x`. We repeat it as many times as there
    # are iterations in the for-loop
    repeated = x.unsqueeze(0).repeat(x.size(0), 1, 1)

    # we're looking to build a `prevs` tensor such that
    # prevs[i, x, y] == prev[x, y] at the i-th iteration of the loop in v1,
    # up to the 0-padding necessary to make them all the same size.
    # We need to build a higher-dimensional equivalent of torch.triu
    xs = torch.arange(x.size(0)).reshape(1, -1, 1)
    zs = torch.arange(x.size(0)).reshape(-1, 1, 1)
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)

    # this is an equivalent of the above iteration starting at 1
    prevs = prevs[:-1]
    currs = x[1:]

    # a batched matrix multiplication
    prod = torch.matmul(currs, prevs.transpose(1, 2))
    attn = prod  # same shape
    context = torch.matmul(attn, prevs)

    # equivalent of a higher-dimensional torch.diagonal
    contexts = torch.einsum('iij->ij', context)
    print(contexts)


print(x)

print('\n------ v1 -------\n')
v1()

print('\n------ v2 -------\n')
v2()
which vectorizes your loop, with some caveats. First, I assume x is 2-dimensional. Secondly, I skip taking the softmax, claiming that it doesn't change the size of the input and thus doesn't affect the vectorization. That's true, but unfortunately the softmax of a 0-padded vector v is not equal to the 0-padded softmax of the unpadded v. This can be fixed with renormalization, though. Please let me know if my assumptions are correct and whether this is a good enough starting point for your work.
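For completeness, here is one way the softmax could be folded back in: reuse the same zs < xs comparison as a boolean mask and fill the padded score positions with -inf before the softmax, which has the same effect as the renormalization mentioned above (padded positions receive exactly zero weight). This is only a sketch under the same assumptions as v2 (2-dimensional x, the learned self.attn projection still omitted); by unsqueezing currs, each current state is scored only against its own slice of previous states, so the diagonal extraction is no longer needed:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 6)


def v2_softmax():
    n = x.size(0)
    repeated = x.unsqueeze(0).repeat(n, 1, 1)
    xs = torch.arange(n).reshape(1, -1, 1)
    zs = torch.arange(n).reshape(-1, 1, 1)
    # prevs[i, j] == x[j] for j <= i and 0 otherwise, as in v2
    prevs = torch.where(zs < xs, torch.tensor(0.), repeated)[:-1]
    currs = x[1:]

    # (n-1, 1, d) x (n-1, d, n) -> (n-1, 1, n): each current state is
    # multiplied only with its own batch of previous states
    prod = torch.matmul(currs.unsqueeze(1), prevs.transpose(1, 2))

    # padded positions (j > i) must get zero attention weight,
    # so set their scores to -inf before the softmax
    pad = (zs < xs).squeeze(-1)[:-1].unsqueeze(1)    # (n-1, 1, n)
    attn = F.softmax(prod.masked_fill(pad, float('-inf')), dim=-1)

    contexts = torch.matmul(attn, prevs).squeeze(1)  # (n-1, d)
    print(contexts)


v2_softmax()

The printed rows should match what v1 produces if its attn = prod line is replaced by attn = F.softmax(prod, dim=1).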