How do you use PyTorch PackedSequence in code?

Can someone give a full working example (not a snippet, but something that runs on a variable-length recurrent neural network) of how you would use the PackedSequence method in PyTorch?

There do not seem to be any examples of this in the documentation, on GitHub, or elsewhere on the internet.

https://github.com/pytorch/pytorch/releases/tag/v0.1.10

asked Jun 20 '17 by mikal94305

2 Answers

Not the most beautiful piece of code, but this is what I put together for my own use after going through the PyTorch forums and docs. There are certainly better ways to handle the sorting and restoring, but I chose to do it inside the network itself.

EDIT: See the answer from @tusonggao below, which lets the torch utils take care of the sorting.

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, embedding_vectors=None, tune_embeddings=True, use_gru=True,
                 hidden_size=128, num_layers=1, bidirectional=True, dropout=0.6):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        self.embed.weight.requires_grad = tune_embeddings
        if embedding_vectors is not None:
            assert embedding_vectors.shape[0] == vocab_size and embedding_vectors.shape[1] == embedding_size
            self.embed.weight = nn.Parameter(torch.FloatTensor(embedding_vectors))
        cell = nn.GRU if use_gru else nn.LSTM
        self.rnn = cell(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                        batch_first=True, bidirectional=bidirectional, dropout=dropout)

    def forward(self, x, x_lengths):
        # sort the batch by sequence length (descending), as pack_padded_sequence expects
        sorted_seq_lens, original_ordering = torch.sort(torch.LongTensor(x_lengths), dim=0, descending=True)
        ex = self.embed(x[original_ordering])
        pack = torch.nn.utils.rnn.pack_padded_sequence(ex, sorted_seq_lens.tolist(), batch_first=True)
        out, _ = self.rnn(pack)
        # pad the packed output back to a regular (batch, seq, features) tensor
        unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        # pick the output at the last valid time step of each sequence
        indices = Variable(torch.LongTensor(np.array(unpacked_len) - 1).view(-1, 1)
                                                                       .expand(unpacked.size(0), unpacked.size(2))
                                                                       .unsqueeze(1))
        last_encoded_states = unpacked.gather(dim=1, index=indices).squeeze(dim=1)
        # scatter the rows back into the original (pre-sort) batch order
        scatter_indices = Variable(original_ordering.view(-1, 1).expand_as(last_encoded_states))
        encoded_reordered = last_encoded_states.clone().scatter_(dim=0, index=scatter_indices, src=last_encoded_states)
        return encoded_reordered
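
For completeness, here is a minimal usage sketch of the Encoder above; the token ids, lengths, and sizes are made up for illustration and are not part of the original answer:

# hypothetical batch of 3 padded sequences of token ids, with 0 as the padding index
x = torch.LongTensor([[4, 9, 2, 0],
                      [7, 1, 0, 0],
                      [3, 5, 6, 8]])
x_lengths = [3, 2, 4]

encoder = Encoder(vocab_size=10, embedding_size=50, hidden_size=128)
encoded = encoder(x, x_lengths)
print(encoded.size())  # torch.Size([3, 256]) -- one vector per sequence (bidirectional GRU), in the original batch order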
answered Nov 15 '22 by chiragjn

Actually, there is no need to handle the sorting and restoring yourself; let the torch.nn.utils.rnn.pack_padded_sequence function do all the work by passing enforce_sorted=False.

The returned PackedSequence object then carries the sorting information in its sorted_indices and unsorted_indices attributes, which the subsequent nn.GRU or nn.LSTM uses to restore the original batch order.
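
To see those attributes directly, here is a quick illustration (the tensors below are made up, and PyTorch >= 1.1 is assumed):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

seqs = [torch.randn(1, 3), torch.randn(4, 3), torch.randn(2, 3)]  # variable-length sequences
packed = pack_padded_sequence(pad_sequence(seqs, batch_first=True),
                              lengths=[1, 4, 2], batch_first=True, enforce_sorted=False)
print(packed.sorted_indices)    # tensor([1, 2, 0]) -- the order used internally (longest first)
print(packed.unsorted_indices)  # tensor([2, 0, 1]) -- maps the internal order back to the input order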

Runnable code example:

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

data = [torch.tensor([1]),
        torch.tensor([2, 3, 4, 5]), 
        torch.tensor([6, 7]),
        torch.tensor([8, 9, 10])]
lengths = [d.size(0) for d in data]

padded_data = pad_sequence(data, batch_first=True, padding_value=0) 
embedding = nn.Embedding(20, 5, padding_idx=0)
embedded_data = embedding(padded_data)

packed_data = pack_padded_sequence(embedded_data, lengths, batch_first=True, enforce_sorted=False)
lstm = nn.LSTM(5, 5, batch_first=True)
o, (h, c) = lstm(packed_data)

# (h, c) are the final hidden and cell states, with the batch order already restored by the LSTM.
# o, however, is still a PackedSequence object; to restore it to the original order:

unpacked_o, unpacked_lengths = pad_packed_sequence(o, batch_first=True)
# now unpacked_o and (h, c) look just like the normal output you would expect from an LSTM layer.

print(unpacked_o, unpacked_lengths)

The output (unpacked_o, unpacked_lengths) looks something like this:

# output (unpacked_o, unpacked_lengths):
tensor([[[ 1.5230, -1.7530,  0.5462,  0.6078,  0.9440],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 1.8888, -0.5465,  0.5404,  0.4132, -0.3266],
         [ 0.1657,  0.5875,  0.4556, -0.8858,  1.1443],
         [ 0.8957,  0.8676, -0.6614,  0.6751, -1.2377],
         [-1.8999,  2.8260,  0.1650, -0.6244,  1.0599]],

        [[ 0.0637,  0.3936, -0.4396, -0.2788,  0.1282],
         [ 0.5443,  0.7401,  1.0287, -0.1538, -0.2202],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.5008,  2.1262, -0.3623,  0.5864,  0.9871],
         [-0.6996, -0.3984,  0.4890, -0.8122, -1.0739],
         [ 0.3392,  1.1305, -0.6669,  0.5054, -1.7222],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<IndexSelectBackward>) tensor([1, 4, 2, 3])

Comparing this with the original data and lengths, we can see that the sorting and restoring has been neatly taken care of.
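
As a further check (a small sketch, assuming the unidirectional, single-layer LSTM above), the output at the last valid time step of each row of unpacked_o should match the final hidden state h:

# gather the output at the last valid time step of each sequence and compare with h
last_steps = unpacked_o[torch.arange(len(lengths)), unpacked_lengths - 1]
print(torch.allclose(last_steps, h.squeeze(0)))  # True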

answered Nov 15 '22 by tusonggao