I am trying to implement a seq2seq model in PyTorch and I am having some problems with batching. For example, I have a batch of data whose dimensions are
[batch_size, sequence_lengths, encoding_dimension]
where the sequence lengths are different for each example in the batch.
Now, I managed to do the encoding part by padding each element in the batch to the length of the longest sequence.
This way, if I feed my network a batch with the shape described above, I get the following outputs:
output, of shape
[batch_size, sequence_lengths, hidden_layer_dimension]
hidden state, of shape
[batch_size, hidden_layer_dimension]
cell state, of shape
[batch_size, hidden_layer_dimension]
Now, from the output, I take for each sequence the last relevant element, that is, the element along the sequence_lengths dimension corresponding to the last non-padded element of the sequence. Thus the final output I get is of shape [batch_size, hidden_layer_dimension].
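The selection step described above can be sketched with `torch.gather`. This is only an illustration: the helper name `last_relevant` and the toy tensors are assumptions, not code from the question.

```python
import torch

def last_relevant(output, lengths):
    """Pick output[b, lengths[b] - 1, :] for every sequence b in the batch."""
    batch_size, _, hidden_dim = output.shape
    # Index of the last real time step, expanded to [batch, 1, hidden]
    idx = (lengths - 1).view(batch_size, 1, 1).expand(batch_size, 1, hidden_dim)
    return output.gather(1, idx).squeeze(1)  # [batch, hidden]

# Toy padded encoder output: batch of 2, max length 3, hidden size 2
output = torch.arange(2 * 3 * 2, dtype=torch.float32).view(2, 3, 2)
lengths = torch.tensor([3, 1])  # true (unpadded) lengths
last = last_relevant(output, lengths)
print(last.shape)  # torch.Size([2, 2])
```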
But now I have the problem of decoding from this vector. How do I handle decoding sequences of different lengths in the same batch? I tried to google it and found this, but it doesn't seem to address the problem. I thought of decoding element by element for the whole batch, but then I have the problem of passing the initial hidden states, given that the ones from the encoder will be of shape [batch_size, hidden_layer_dimension], while the ones from the decoder will be of shape [1, hidden_layer_dimension].
Am I missing something? Thanks for the help!
You are not missing anything. I have worked on several sequence-to-sequence applications using PyTorch, so here is a simple example.
import torch
import torch.nn as nn

# EmbeddingLayer, Encoder, Decoder and helper are defined elsewhere in the
# complete source code and are not shown here.


class Seq2Seq(nn.Module):
    """A Seq2seq network trained on predicting the next query."""

    def __init__(self, dictionary, embedding_index, args):
        super(Seq2Seq, self).__init__()
        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index, self.config.emsize)

        self.encoder = Encoder(self.config.emsize, self.config.nhid_enc, self.config.bidirection, self.config)
        self.decoder = Decoder(self.config.emsize, self.config.nhid_enc * self.num_directions, len(dictionary),
                               self.config)

    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        # Negative log-likelihood of the target token for each example in the batch
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1)).squeeze(1)
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        # Zero out the loss at padded positions
        losses = losses * mask.float()
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            return losses.sum(), 0
        else:
            return losses.sum(), num_non_zero_elem[0]

    def forward(self, q1_var, q1_len, q2_var, q2_len):
        # encode the query
        embedded_q1 = self.embedding(q1_var)
        encoded_q1, hidden = self.encoder(embedded_q1, q1_len)

        # build the initial decoder state from the last encoder layer
        if self.config.bidirection:
            if self.config.model == 'LSTM':
                h_t, c_t = hidden[0][-2:], hidden[1][-2:]
                # concatenate the forward and backward states along the hidden dimension
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2), torch.cat(
                    (c_t[0].unsqueeze(0), c_t[1].unsqueeze(0)), 2)
            else:
                # for non-LSTM models, hidden is a single tensor rather than an (h, c) tuple
                h_t = hidden[-2:]
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2)
        else:
            if self.config.model == 'LSTM':
                decoder_hidden = hidden[0][-1], hidden[1][-1]
            else:
                decoder_hidden = hidden[-1]

        # decode one time step at a time for the whole batch (teacher forcing)
        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(q2_var.size(1) - 1):
            input_variable = q2_var[:, idx]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
            local_loss, num_local_loss = self.compute_decoding_loss(decoder_output, q2_var[:, idx + 1], idx, q2_len)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        # average over the number of non-padded target tokens
        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss
You can see the complete source code here. This application is about predicting users' next web-search query given the current web-search query.
To answer your question:
How do I handle a decoding of sequences of different lengths in the same batch?
You have padded sequences, so you can treat all the sequences as having the same length. But when you compute the loss, you need to ignore the loss for the padded terms by masking.
The example above uses this masking technique.
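Since `helper.mask` is not shown in the answer, here is a minimal, self-contained sketch of what such a per-step mask could look like. The names `step_mask` and `masked_step_loss`, and the rule "the target at position step + 1 is real if step + 1 < length", are assumptions made for illustration and may differ from the author's helper.

```python
import torch
import torch.nn.functional as F

def step_mask(lengths, step):
    # True where the target token at position step + 1 is a real (non-padded) token.
    # Hypothetical stand-in for helper.mask, which is not shown in the answer.
    return (step + 1) < lengths

def masked_step_loss(log_probs, target, lengths, step):
    # log_probs: [batch, vocab] log-probabilities, target: [batch] token ids
    losses = F.nll_loss(log_probs, target, reduction="none")  # [batch]
    mask = step_mask(lengths, step)
    # Return the summed loss over real tokens and how many real tokens there were
    return (losses * mask.float()).sum(), int(mask.sum())

log_probs = torch.log_softmax(torch.zeros(3, 4), dim=1)  # uniform over 4 tokens
target = torch.tensor([1, 2, 0])
lengths = torch.tensor([3, 2, 1])  # true target lengths
loss_sum, count = masked_step_loss(log_probs, target, lengths, step=1)
print(loss_sum.item(), count)
```

At `step=1` only the first sequence still has a real target, so only its loss term contributes to the sum.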
Also, you are absolutely correct that you need to decode element by element for the mini-batches. The initial decoder state of shape [batch_size, hidden_layer_dimension] is also fine. You just need to unsqueeze it at dimension 0 to make it [1, batch_size, hidden_layer_dimension].
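As a small illustration of the shape fix, the sketch below feeds an unsqueezed state to a single-layer, unidirectional `nn.LSTM`. The sizes and the random tensors standing in for the encoder states are assumptions.

```python
import torch
import torch.nn as nn

batch_size, hidden_dim, emb_dim = 4, 8, 5

# Stand-ins for the encoder's final hidden and cell states: [batch, hidden]
enc_h = torch.randn(batch_size, hidden_dim)
enc_c = torch.randn(batch_size, hidden_dim)

# nn.LSTM expects states of shape [num_layers * num_directions, batch, hidden],
# regardless of batch_first, so add a leading dimension of size 1
h0 = enc_h.unsqueeze(0)
c0 = enc_c.unsqueeze(0)

decoder_rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
step_input = torch.randn(batch_size, 1, emb_dim)  # one time step for the whole batch
out, (h1, c1) = decoder_rnn(step_input, (h0, c0))
print(h0.shape, out.shape, h1.shape)
```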
Please note that you do not need to loop over each example in the batch; you can process the whole batch at once, but you do need to loop over the elements of the sequences.
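To make the loop structure concrete, here is a stripped-down sketch with a toy GRU decoder. The layer sizes, the random target batch, and the random initial state are all assumptions; it only shows that the loop runs over time steps while each step handles the entire batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, emb_dim, hidden_dim, batch_size, max_len = 10, 6, 8, 3, 5

embedding = nn.Embedding(vocab, emb_dim, padding_idx=0)
gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
project = nn.Linear(hidden_dim, vocab)

# Toy padded target batch (token id 0 is padding) with its true lengths
lengths = torch.tensor([5, 3, 2])
targets = torch.randint(1, vocab, (batch_size, max_len))
for b, n in enumerate(lengths.tolist()):
    targets[b, n:] = 0

hidden = torch.randn(1, batch_size, hidden_dim)  # stand-in for the encoder state

total_loss, total_count = 0.0, 0
for t in range(max_len - 1):  # loop over time steps, not over examples
    step_in = embedding(targets[:, t]).unsqueeze(1)   # [batch, 1, emb]
    out, hidden = gru(step_in, hidden)                # out: [batch, 1, hidden]
    log_probs = torch.log_softmax(project(out.squeeze(1)), dim=1)  # [batch, vocab]
    step_losses = -log_probs.gather(1, targets[:, t + 1].unsqueeze(1)).squeeze(1)
    mask = (t + 1) < lengths                          # ignore padded targets
    total_loss = total_loss + (step_losses * mask.float()).sum()
    total_count += int(mask.sum())

mean_loss = total_loss / max(total_count, 1)
print(total_count, mean_loss.item())
```

Shorter sequences keep being fed through the decoder after they end, but the mask keeps those steps out of the loss.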