
How does batching work in a seq2seq model in pytorch?

I am trying to implement a seq2seq model in PyTorch and I am having some problems with batching. For example, I have a batch of data whose dimensions are

[batch_size, sequence_lengths, encoding_dimension]

where the sequence lengths are different for each example in the batch.

Now, I managed to do the encoding part by padding each element in the batch to the length of the longest sequence.

This way, if I feed my network a batch of the shape above, I get the following outputs:

output, of shape [batch_size, sequence_lengths, hidden_layer_dimension]

hidden state, of shape [batch_size, hidden_layer_dimension]

cell state, of shape [batch_size, hidden_layer_dimension]

Now, from the output, I take the last relevant element of each sequence, that is, the element along the sequence_lengths dimension corresponding to the last non-padded element of the sequence. Thus the final output I get is of shape [batch_size, hidden_layer_dimension].
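The selection of the last non-padded output described above can be sketched with `gather`. This is only an illustration: the tensor sizes, the random `output` tensor standing in for the encoder output, and the `lengths` values are all made up.

```python
import torch

batch_size, max_len, hidden_dim = 3, 4, 6
output = torch.randn(batch_size, max_len, hidden_dim)  # stand-in for the padded encoder output
lengths = torch.tensor([4, 2, 3])                      # hypothetical true (unpadded) lengths

# Index of the last real time step of each sequence, expanded to [batch_size, 1, hidden_dim]
last_idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, hidden_dim)

# Pick one time step per sequence along dimension 1, then drop that dimension
last_output = output.gather(1, last_idx).squeeze(1)    # [batch_size, hidden_dim]
```

For the second sequence (length 2), `last_output[1]` is `output[1, 1]`, so the padded steps at positions 2 and 3 are never read.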

But now I have the problem of decoding from this vector. How do I handle decoding sequences of different lengths in the same batch? I tried to google it and found this, but it doesn't seem to address the problem. I thought of decoding element by element for the whole batch, but then I have the problem of passing the initial hidden states, given that the ones from the encoder will be of shape [batch_size, hidden_layer_dimension], while the ones from the decoder will be of shape [1, hidden_layer_dimension].

Am I missing something? Thanks for the help!

GPaolo asked Mar 14 '18 16:03


1 Answer

You are not missing anything. I have worked on several sequence-to-sequence applications using PyTorch, so here is a simple example.

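import torch
import torch.nn as nn

# EmbeddingLayer, Encoder, Decoder and helper are not defined in this excerpt;
# they belong to the complete source code referenced after the listing.


class Seq2Seq(nn.Module):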
    """A Seq2seq network trained on predicting the next query."""

    def __init__(self, dictionary, embedding_index, args):
        super(Seq2Seq, self).__init__()

        self.config = args
        self.num_directions = 2 if self.config.bidirection else 1

        self.embedding = EmbeddingLayer(len(dictionary), self.config)
        self.embedding.init_embedding_weights(dictionary, embedding_index, self.config.emsize)

        self.encoder = Encoder(self.config.emsize, self.config.nhid_enc, self.config.bidirection, self.config)
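        # NOTE: this call was cut off in the source; the final argument is assumed by analogy with Encoder
        self.decoder = Decoder(self.config.emsize, self.config.nhid_enc * self.num_directions, len(dictionary),
                               self.config)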

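    @staticmethod
    def compute_decoding_loss(logits, target, seq_idx, length):
        # Negative log-probability of each target token; assumes logits holds log-probabilities (batch x vocab)
        losses = -torch.gather(logits, dim=1, index=target.unsqueeze(1)).squeeze(1)
        mask = helper.mask(length, seq_idx)  # mask: batch x 1
        # Flatten the mask so it lines up elementwise with losses (batch) instead of broadcasting
        losses = losses * mask.float().view(-1)
        num_non_zero_elem = torch.nonzero(mask.data).size()
        if not num_non_zero_elem:
            return losses.sum(), 0
        return losses.sum(), num_non_zero_elem[0]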

    def forward(self, q1_var, q1_len, q2_var, q2_len):
        # encode the query
        embedded_q1 = self.embedding(q1_var)
        encoded_q1, hidden = self.encoder(embedded_q1, q1_len)

        # build the initial decoder state from the final encoder state
        if self.config.bidirection:
            if self.config.model == 'LSTM':
                # concatenate the forward and backward states of the last layer
                h_t, c_t = hidden[0][-2:], hidden[1][-2:]
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2), torch.cat(
                    (c_t[0].unsqueeze(0), c_t[1].unsqueeze(0)), 2)
            else:
                h_t = hidden[0][-2:]
                decoder_hidden = torch.cat((h_t[0].unsqueeze(0), h_t[1].unsqueeze(0)), 2)
        else:
            if self.config.model == 'LSTM':
                decoder_hidden = hidden[0][-1], hidden[1][-1]
            else:
                decoder_hidden = hidden[-1]

        decoding_loss, total_local_decoding_loss_element = 0, 0
        for idx in range(q2_var.size(1) - 1):
            input_variable = q2_var[:, idx]
            embedded_decoder_input = self.embedding(input_variable).unsqueeze(1)
            decoder_output, decoder_hidden = self.decoder(embedded_decoder_input, decoder_hidden)
            local_loss, num_local_loss = self.compute_decoding_loss(decoder_output, q2_var[:, idx + 1], idx, q2_len)
            decoding_loss += local_loss
            total_local_decoding_loss_element += num_local_loss

        if total_local_decoding_loss_element > 0:
            decoding_loss = decoding_loss / total_local_decoding_loss_element

        return decoding_loss

You can see the complete source code here. This application is about predicting users' next web-search query given the current web-search query.

The answer to your question:

How do I handle a decoding of sequences of different lengths in the same batch?

You have padded the sequences, so you can treat them all as having the same length. But when you compute the loss, you need to ignore the loss for the padded positions by using masking.

I have used a masking technique to achieve the same in the above example.
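As a rough, self-contained illustration of that masking idea (not the `helper.mask` function from the example, whose implementation is not shown here), a per-step masked loss could look like the sketch below. The function name `masked_nll_step`, the convention that position `step` is real when `step < length`, and all tensor values are assumptions made for the illustration.

```python
import torch
import torch.nn.functional as F


def masked_nll_step(log_probs, target, step, lengths):
    """Summed negative log-likelihood at one decoding step, ignoring padded positions.

    log_probs: [batch_size, vocab_size] log-probabilities at this step
    target:    [batch_size] target token ids at this step
    step:      index of the target position within the sequence
    lengths:   [batch_size] true (unpadded) target lengths
    """
    losses = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # [batch_size]
    mask = (lengths > step).float()  # 1.0 for real positions, 0.0 for padding
    return (losses * mask).sum(), int(mask.sum().item())


log_probs = F.log_softmax(torch.randn(3, 10), dim=1)
target = torch.tensor([1, 0, 7])
lengths = torch.tensor([4, 2, 3])

# At step 2 the second sequence (length 2) is already padding, so it contributes nothing.
loss, n_valid = masked_nll_step(log_probs, target, step=2, lengths=lengths)
```

Dividing the accumulated loss by the accumulated count of valid positions, as the example does at the end of `forward`, gives an average over real tokens only.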

Also, you are absolutely correct that you need to decode element by element for the mini-batches. An initial decoder state of shape [batch_size, hidden_layer_dimension] is also fine; you just need to unsqueeze it at dimension 0 to make it [1, batch_size, hidden_layer_dimension].
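A minimal sketch of that reshaping, assuming a single-layer, unidirectional `nn.LSTM` decoder; the sizes and the random tensors standing in for the encoder states are made up:

```python
import torch
import torch.nn as nn

batch_size, emb_dim, hidden_dim = 3, 5, 6
encoder_h = torch.randn(batch_size, hidden_dim)  # stand-in for the encoder's final hidden state
encoder_c = torch.zeros(batch_size, hidden_dim)  # stand-in for the encoder's final cell state

# nn.LSTM expects states of shape [num_layers * num_directions, batch_size, hidden_dim]
h0 = encoder_h.unsqueeze(0)
c0 = encoder_c.unsqueeze(0)

decoder_rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
step_input = torch.randn(batch_size, 1, emb_dim)  # one time step for the whole batch
step_output, (h1, c1) = decoder_rnn(step_input, (h0, c0))
```

The returned states `h1` and `c1` already have the `[1, batch_size, hidden_dim]` shape, so they can be passed straight into the next decoding step.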

Please note that you do not need to loop over each example in the batch. You can process the whole batch at once, but you do need to loop over the elements of the sequences.
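The loop structure can be sketched as follows, using teacher forcing and a GRU for brevity. The layer names (`embedding`, `decoder_rnn`, `projection`), the sizes, and the random inputs are placeholders chosen for this illustration, not the classes from the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, max_len, vocab_size, emb_dim, hidden_dim = 3, 5, 11, 4, 6

embedding = nn.Embedding(vocab_size, emb_dim)
decoder_rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
projection = nn.Linear(hidden_dim, vocab_size)

targets = torch.randint(0, vocab_size, (batch_size, max_len))  # padded target token ids
hidden = torch.randn(1, batch_size, hidden_dim)                # stand-in for the state from the encoder

step_logits = []
for t in range(max_len - 1):                             # loop over time steps only
    step_input = embedding(targets[:, t]).unsqueeze(1)   # [batch_size, 1, emb_dim]
    out, hidden = decoder_rnn(step_input, hidden)        # the whole batch in one call
    step_logits.append(projection(out.squeeze(1)))       # [batch_size, vocab_size]

logits = torch.stack(step_logits, dim=1)                 # [batch_size, max_len - 1, vocab_size]
```

There is no loop over `batch_size` anywhere; padded positions are simply decoded along with the real ones and then excluded when the loss is computed.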

Wasi Ahmad answered Oct 27 '22 22:10
