How to correctly implement a batch-input LSTM network in PyTorch?

Tags:

pytorch

This release of PyTorch seems to provide PackedSequence for variable-length inputs to recurrent neural networks. However, I found it a bit hard to use correctly.

Using pad_packed_sequence to recover the output of an RNN layer that was fed by pack_padded_sequence, we get a T x B x N tensor, outputs, where T is the maximum number of time steps, B is the batch size and N is the hidden size. I found that for the shorter sequences in the batch, the outputs after their last valid step are all zeros.

Here are my questions.

  1. For a single-output task, where one needs the last output of every sequence, simply taking outputs[-1] gives a wrong result, since this tensor contains zeros for the short sequences. One needs to construct indices from the sequence lengths to fetch the last output of each sequence. Is there a simpler way to do that?
  2. For a multiple-output task (e.g. seq2seq), one usually adds an N x O linear layer, reshapes the batch outputs from T x B x O into TB x O, and computes the cross-entropy loss against the true targets of size TB (usually integers in a language model). In this situation, do these zeros in the batch output matter?
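The problem in question 1 can be made concrete with a framework-free sketch. Plain Python lists stand in for the T x B x N tensor and the numbers are invented; last_outputs_by_length is a hypothetical helper for illustration, not a PyTorch API.

```python
# Framework-free sketch: padded RNN outputs in T x B x N (time-major) layout,
# as pad_packed_sequence returns them with batch_first=False.
# Nested lists stand in for tensors; all values are invented.

def last_outputs_by_length(outputs, lengths):
    """Return the output at the last valid timestep of each sequence.

    outputs: nested list indexed as outputs[t][b][n], zero-padded past each length.
    lengths: list of valid lengths, one per batch element.
    """
    return [outputs[length - 1][b] for b, length in enumerate(lengths)]


# Batch of B=2 sequences with lengths 3 and 1, hidden size N=2, T=3.
lengths = [3, 1]
outputs = [
    [[0.1, 0.2], [0.7, 0.8]],  # t = 0
    [[0.3, 0.4], [0.0, 0.0]],  # t = 1 (sequence 1 is already padded)
    [[0.5, 0.6], [0.0, 0.0]],  # t = 2
]

naive = outputs[-1]  # [[0.5, 0.6], [0.0, 0.0]]: zeros for the short sequence
correct = last_outputs_by_length(outputs, lengths)  # [[0.5, 0.6], [0.7, 0.8]]
```

Indexing each sequence at lengths[b] - 1 is exactly the selection that has to be expressed with tensor operations (for example gather) in PyTorch.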
Saddle Point asked Sep 24 '17 07:09


1 Answer

Question 1 - Last Timestep

This is the code that I use to get the output of the last timestep. I don't know if there is a simpler solution; if there is, I'd like to know it. I followed this discussion and adapted the relevant code snippet for my last_timestep method. This is my forward:

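import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
        # Elided in the original. forward() expects self.embedding,
        # self.rnn (an nn.LSTM built with batch_first=True) and self.linear.
        ...

    def last_timestep(self, unpacked, lengths):
        # unpacked: B x T x N (batch_first=True); lengths: int64 tensor of size B.
        # Index of the last output for each sequence: a B x 1 x N tensor
        # whose row b is filled with lengths[b] - 1.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        idx = idx.to(unpacked.device)
        # Gather along the time dimension, then drop only that dimension
        # (a bare squeeze() would also drop the batch dimension when B == 1).
        return unpacked.gather(1, idx).squeeze(1)

    def forward(self, x, lengths):
        embs = self.embedding(x)

        # Pack the batch. Lengths must live on the CPU; enforce_sorted=False
        # (PyTorch >= 1.1) removes the need to sort the batch by length first.
        packed = pack_padded_sequence(embs, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)

        out_packed, (h, c) = self.rnn(packed)

        # Padded positions come back as zeros, in the original batch order.
        out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

        # get the outputs from the last *non-masked* timestep for each sentence
        last_outputs = self.last_timestep(out_unpacked, lengths)

        # project to the classes using a linear layer
        logits = self.linear(last_outputs)

        return logits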
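For readers unfamiliar with gather, the following framework-free sketch mirrors what last_timestep computes, with nested lists standing in for tensors (batch-first layout, invented values). build_index and gather_dim1 are hypothetical helpers that imitate the view/expand/unsqueeze chain and torch.gather(input, 1, index) respectively; they are not part of PyTorch.

```python
# Framework-free sketch of last_timestep(), batch-first: unpacked[b][t][n].
# Nested lists stand in for tensors; the helpers below are hypothetical.

def build_index(lengths, hidden_size):
    # Mirrors (lengths - 1).view(-1, 1).expand(B, N).unsqueeze(1):
    # shape B x 1 x N, every entry in row b equal to lengths[b] - 1.
    return [[[length - 1] * hidden_size] for length in lengths]


def gather_dim1(src, index):
    # Mirrors torch.gather(src, 1, index) for 3-D inputs:
    # out[b][i][n] = src[b][index[b][i][n]][n]
    return [
        [[src[b][index[b][i][n]][n] for n in range(len(index[b][i]))]
         for i in range(len(index[b]))]
        for b in range(len(index))
    ]


def last_timestep(unpacked, lengths):
    hidden_size = len(unpacked[0][0])
    gathered = gather_dim1(unpacked, build_index(lengths, hidden_size))
    # The squeeze step: drop the singleton time dimension -> B x N.
    return [row[0] for row in gathered]


unpacked = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # sequence 0, length 3
    [[0.7, 0.8], [0.0, 0.0], [0.0, 0.0]],  # sequence 1, length 1 (padded)
]
print(last_timestep(unpacked, [3, 1]))  # [[0.5, 0.6], [0.7, 0.8]]
```

The index is expanded across the hidden dimension because gather picks one element per output position, so every hidden unit of sequence b needs its own copy of lengths[b] - 1.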

Question 2 - Masked Cross Entropy Loss

Yes, by default the zero-padded timesteps (targets) matter: they are counted in the averaged loss like any other target. However, masking them is easy. You have two options, depending on the version of PyTorch you use.

  1. PyTorch 0.2.0: PyTorch now supports masking directly in CrossEntropyLoss, via the ignore_index argument. For example, in language modeling or seq2seq, where I pad with zeros, I mask the zero-padded words (targets) simply like this:

    loss_function = nn.CrossEntropyLoss(ignore_index=0)

  2. PyTorch 0.1.12 and older: older versions of PyTorch did not support masking, so you had to implement your own workaround. A solution that I used was masked_cross_entropy.py, by jihunchoi. You may also be interested in this discussion.
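To see what the masking changes, here is a framework-free sketch with invented logits, assuming index 0 is reserved for padding as in the answer. mean_ce_ignore_index mimics the default mean reduction of nn.CrossEntropyLoss(ignore_index=0) (average over non-ignored targets only), and mean_ce_length_mask is a hypothetical length-based mask, one common way to write such a workaround; neither is jihunchoi's code.

```python
import math


def cross_entropy(logits, target):
    # -log softmax(logits)[target], using a stable log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mean_ce_ignore_index(logits_seq, targets, ignore_index=0):
    # Mimics nn.CrossEntropyLoss(ignore_index=...) with mean reduction:
    # average only over positions whose target != ignore_index.
    losses = [cross_entropy(l, t) for l, t in zip(logits_seq, targets)
              if t != ignore_index]
    return sum(losses) / len(losses)


def mean_ce_length_mask(logits_batch, targets_batch, lengths):
    # Hypothetical length-based mask: step t of sequence b counts only if t < lengths[b].
    total, count = 0.0, 0
    for logits_seq, targets, length in zip(logits_batch, targets_batch, lengths):
        for t in range(length):
            total += cross_entropy(logits_seq[t], targets[t])
            count += 1
    return total / count


# Batch of B=2 sequences, T=3 steps, O=3 classes; index 0 is padding.
lengths = [3, 1]
targets = [[1, 2, 1], [2, 0, 0]]
pad_logits = [0.5, 0.1, -0.2]  # invented: e.g. a linear layer's bias for a zero output
logits = [
    [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [0.0, 2.0, 0.0]],
    [[0.0, 0.0, 2.0], pad_logits, pad_logits],
]

# Flatten to B*T positions (batch-first here), as when reshaping before the loss.
flat_logits = [step for seq in logits for step in seq]
flat_targets = [t for seq in targets for t in seq]

masked = mean_ce_ignore_index(flat_logits, flat_targets, ignore_index=0)
by_length = mean_ce_length_mask(logits, targets, lengths)
unmasked = sum(cross_entropy(l, t)
               for l, t in zip(flat_logits, flat_targets)) / len(flat_targets)
# masked and by_length agree; unmasked also averages in the padded positions.
```

Reserving index 0 for padding means no real token may use class 0; the length-based mask avoids that constraint but needs the lengths at loss time.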

Christos Baziotis answered Sep 30 '22 23:09
