In PyTorch, we can define architectures in multiple ways. Here, I'd like to create a simple LSTM network using the Sequential
module.
In Lua Torch, I would usually write:
model = nn.Sequential()
model:add(nn.SplitTable(1,2))
model:add(nn.Sequencer(nn.LSTM(inputSize, hiddenSize)))
model:add(nn.SelectTable(-1)) -- last step of output sequence
model:add(nn.Linear(hiddenSize, classes_n))
However, in PyTorch, I can't find an equivalent of SelectTable
to get the last output of the sequence.
nn.Sequential(
    nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
    # what to put here to retrieve the last output of the LSTM?
    nn.Linear(hiddenSize, classes_n))
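nn.LSTM returns a tuple rather than a single tensor, so it cannot feed nn.Linear directly. Define a small module that unpacks the tuple and keeps only the output of the last time step: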
import torch.nn as nn

# nn.LSTM returns a tuple: (output, (h_n, c_n))
class extract_tensor(nn.Module):
    def forward(self, x):
        # output has shape (batch, seq_len, hidden) because batch_first=True
        tensor, _ = x
        # Keep only the last time step: shape (batch, hidden)
        return tensor[:, -1, :]

nn.Sequential(
    nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
    extract_tensor(),
    nn.Linear(hiddenSize, classes_n)
)
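To check that the shapes line up, here is a minimal, self-contained sketch of the same idea. The sizes (inputSize=8, hiddenSize=16, classes_n=3, a batch of 4 sequences of length 10) are made-up values for illustration only, not something from the question.

```python
import torch
import torch.nn as nn

class extract_tensor(nn.Module):
    """Pick the last time step from the (output, (h_n, c_n)) tuple returned by nn.LSTM."""
    def forward(self, x):
        output, _ = x            # output: (batch, seq_len, hidden) with batch_first=True
        return output[:, -1, :]  # (batch, hidden)

# Hypothetical sizes, chosen only for this check
inputSize, hiddenSize, classes_n = 8, 16, 3
batch, seq_len = 4, 10

model = nn.Sequential(
    nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
    extract_tensor(),
    nn.Linear(hiddenSize, classes_n),
)

x = torch.randn(batch, seq_len, inputSize)
logits = model(x)
print(logits.shape)  # torch.Size([4, 3])

# For a single-layer, unidirectional LSTM, the last output step
# should match the final hidden state h_n[-1].
output, (h_n, c_n) = model[0](x)
same = torch.allclose(output[:, -1, :], h_n[-1])
print(same)  # True
```

One caveat: taking index -1 assumes every sequence in the batch really has seq_len valid steps. If the batch is padded to a common length, the last position may be padding for the shorter sequences, and you would need to index by each sequence's true length instead.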