
Simple LSTM in PyTorch with Sequential module

In PyTorch, we can define architectures in multiple ways. Here, I'd like to create a simple LSTM network using the Sequential module.

In Lua's torch I would usually go with:

model = nn.Sequential()
model:add(nn.SplitTable(1,2))
model:add(nn.Sequencer(nn.LSTM(inputSize, hiddenSize)))
model:add(nn.SelectTable(-1)) -- last step of output sequence
model:add(nn.Linear(hiddenSize, classes_n))

However, in PyTorch, I can't find an equivalent of SelectTable to retrieve the last output of the sequence.

nn.Sequential(
  nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
  # what to put here to retrieve the last output of the LSTM?
  nn.Linear(hiddenSize, classes_n))
BiBi asked May 23 '17 09:05

1 Answer

Define a small module that extracts the LSTM output at the last time step:

import torch.nn as nn

# nn.LSTM returns a tuple: (output, (h_n, c_n))
class extract_tensor(nn.Module):
    def forward(self, x):
        # With batch_first=True, output has shape (batch, seq_len, hidden)
        tensor, _ = x
        # Keep only the last time step: shape (batch, hidden)
        return tensor[:, -1, :]

nn.Sequential(
    nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
    extract_tensor(),
    nn.Linear(hiddenSize, classes_n)
)
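
To check the shapes end to end, here is a minimal, self-contained sketch of the same idea. The sizes and the names LastTimeStep, input_size, hidden_size and num_classes are hypothetical, chosen only for illustration, and the input is assumed to be a plain (unpacked) batch where every sequence has the same length.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
input_size, hidden_size, num_classes = 8, 16, 3
batch_size, seq_len = 4, 5


class LastTimeStep(nn.Module):
    """Keep only the final time step of a batch_first LSTM output."""

    def forward(self, lstm_out):
        # nn.LSTM returns (output, (h_n, c_n)); output is (batch, seq_len, hidden).
        output, _ = lstm_out
        return output[:, -1, :]  # (batch, hidden)


model = nn.Sequential(
    nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True),
    LastTimeStep(),
    nn.Linear(hidden_size, num_classes),
)

x = torch.randn(batch_size, seq_len, input_size)
logits = model(x)
print(logits.shape)  # torch.Size([4, 3])

# For a single-layer, unidirectional LSTM the last output step matches h_n[-1].
output, (h_n, _) = model[0](x)
same = torch.allclose(output[:, -1, :], h_n[-1])
```

If the batch contains padded sequences of different lengths, the last position may be padding, so indexing with -1 would probably need to be replaced by per-sequence lengths.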
almamon Rasool Abd Ali answered Oct 11 '22 15:10