Taking the last state from BiLSTM (BiGRU) in PyTorch

After reading several articles, I am still quite confused about the correctness of my implementation for getting the last hidden states from a BiLSTM. These are the sources I have read:

  1. Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  2. PackedSequence for seq2seq model (PyTorch forums)
  3. What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  4. Select tensor in a batch of sequences (PyTorch forums)

The approach from the last source (4) seems the cleanest to me, but I am still not sure I understood the thread correctly. Am I using the right final hidden states from the forward and the reversed LSTM? This is my implementation:

# pos contains indices of words in embedding matrix
# seqlengths contains info about sequence lengths
# so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and 
# seqlengths contains [3,2], we have batch with samples
# of variable length [4,6,9] and [3,1]

all_in_embs = self.in_embeddings(pos)
in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
output,lasthidden = self.rnn(in_emb_seqs)
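# for an LSTM, self.rnn returns (output, (h_n, c_n)); a GRU returns (output, h_n),
# so the cell state only needs to be stripped in the LSTM case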
if not self.data_processor.use_gru:
    lasthidden = lasthidden[0]
# u_emb_batch has shape batch_size x embedding_dimension
# sum last state from forward and backward direction
u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]

Is it correct?

asked Jun 14 '18 by Smarty77


1 Answer

In the general case, if you want to create your own BiLSTM network, you need to create two regular LSTMs and feed one the regular input sequence and the other the inverted (time-reversed) input sequence. After you finish feeding both sequences, you take the last state from each net and tie them together somehow (sum or concatenate).
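A minimal sketch of that manual approach (the sizes and names below are illustrative, not taken from the question):

import torch
import torch.nn as nn

seq_len, batch_size, input_size, hidden_size = 5, 2, 8, 16

forward_lstm = nn.LSTM(input_size, hidden_size)   # reads the sequence left to right
backward_lstm = nn.LSTM(input_size, hidden_size)  # reads the time-reversed sequence

x = torch.randn(seq_len, batch_size, input_size)  # (seq_len, batch, input_size)
x_reversed = torch.flip(x, dims=[0])              # invert the time dimension

_, (h_fwd, _) = forward_lstm(x)                   # h_fwd: (1, batch, hidden_size)
_, (h_bwd, _) = backward_lstm(x_reversed)         # h_bwd: (1, batch, hidden_size)

# tie the two final states together, e.g. by sum or by concatenation
summed = h_fwd[-1] + h_bwd[-1]                            # (batch, hidden_size)
concatenated = torch.cat([h_fwd[-1], h_bwd[-1]], dim=-1)  # (batch, 2 * hidden_size)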

As I understand it, you are using the built-in bidirectional LSTM as in this example (setting bidirectional=True in the nn.LSTM constructor). You then get the concatenated output after feeding the batch, as PyTorch handles all the hassle for you.
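For reference, a small sketch of that built-in variant (the feature sizes are made up for the illustration):

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, bidirectional=True)

x = torch.randn(5, 2, 8)          # (seq_len, batch, input_size)
output, (h_n, c_n) = rnn(x)

print(output.shape)  # torch.Size([5, 2, 32]) -> forward and backward outputs concatenated
print(h_n.shape)     # torch.Size([2, 2, 16]) -> (num_layers * num_directions, batch, hidden_size)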

If that is the case and you want to sum the hidden states, then you have to use

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

assuming you have only one layer. If you have more layers, your variant seems better.

This is because of how the result is structured (see the documentation):

h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len
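If you want to make that layout explicit, the documentation also notes that h_n can be viewed as (num_layers, num_directions, batch, hidden_size); a sketch with made-up sizes:

import torch

num_layers, num_directions, batch_size, hidden_size = 2, 2, 4, 16
h_n = torch.randn(num_layers * num_directions, batch_size, hidden_size)

h_n = h_n.view(num_layers, num_directions, batch_size, hidden_size)
last_layer_forward = h_n[-1, 0]   # (batch, hidden_size)
last_layer_backward = h_n[-1, 1]  # (batch, hidden_size)

u_emb_batch = last_layer_forward + last_layer_backward

This works for any number of layers, which is also why the lasthidden[-1] + lasthidden[-2] indexing from the question picks out exactly the last layer's forward and backward states.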

By the way, you can get the same vectors from output as well. Keep in mind that the backward direction finishes its pass at time step 0, so its final state lives in output[0], not output[-1]:

u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[0, :, HIDDEN_DIM:]

should provide the same result, as long as all sequences in the batch have the same length (otherwise the last valid forward step differs per sequence).
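Since the question feeds a PackedSequence and the batch can contain sequences of different lengths, output has to be unpacked first and the forward direction's last valid step has to be picked per sequence. A sketch along those lines, assuming output is the packed result from the question's self.rnn and HIDDEN_DIM is the LSTM's hidden size:

from torch.nn.utils.rnn import pad_packed_sequence

unpacked, lengths = pad_packed_sequence(output)   # (max_seq_len, batch, 2 * HIDDEN_DIM)

# forward direction: gather each sequence's own last valid time step
idx = (lengths - 1).view(1, -1, 1).expand(1, unpacked.size(1), HIDDEN_DIM)
last_forward = unpacked[:, :, :HIDDEN_DIM].gather(0, idx).squeeze(0)   # (batch, HIDDEN_DIM)

# backward direction: its final state is the one at time step 0
last_backward = unpacked[0, :, HIDDEN_DIM:]                            # (batch, HIDDEN_DIM)

u_emb_batch_2 = last_forward + last_backward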

answered Sep 22 '22 by igrinis