I would like to create my own LSTM class; however, I don't want to rewrite the classic LSTM functions from scratch.
Digging through the PyTorch code, I only found a messy implementation involving at least 3-4 classes with inheritance:
Does a clean PyTorch implementation of an LSTM exist somewhere? Any links would help.
For example, I know that clean implementations of an LSTM exist in TensorFlow, but I would need to derive a PyTorch one.
To give a concrete example, what I'm searching for is an implementation as clean as this, but in PyTorch:
The best implementation I found is here
https://github.com/pytorch/benchmark/blob/master/rnns/benchmarks/lstm_variants/lstm.py
It even implements four different variants of recurrent dropout, which is very useful!
If you take the dropout parts away, you get:
import math
import torch as th
import torch.nn as nn

class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)

        # Linear mappings
        preact = self.i2h(x) + self.h2h(h)

        # activations
        gates = preact[:, :3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size:].tanh()
        i_t = gates[:, :self.hidden_size]
        f_t = gates[:, self.hidden_size:2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size:]

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
        h_t = th.mul(o_t, c_t.tanh())

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)
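For orientation, here is a minimal usage sketch (my own addition, not from the linked repository) that runs the cell on a single time step; the shapes follow the (1, batch, feature) layout implied by the .view() calls in forward:

batch_size, input_size, hidden_size = 3, 10, 20
cell = LSTM(input_size, hidden_size)

x = th.randn(1, batch_size, input_size)    # a single time step
h0 = th.zeros(1, batch_size, hidden_size)  # initial hidden state
c0 = th.zeros(1, batch_size, hidden_size)  # initial cell state

h1, (_, c1) = cell(x, (h0, c0))
print(h1.size(), c1.size())  # both (1, batch_size, hidden_size)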
PS: The repository contains many more variants of LSTM and other RNNs:
https://github.com/pytorch/benchmark/tree/master/rnns/benchmarks.
Check it out, maybe the extension you had in mind is already there!
EDIT:
As mentioned in the comments, you can wrap the LSTM cell above to process whole sequences and produce sequential output:
import math
import torch as th
import torch.nn as nn

class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        # As before

    def reset_parameters(self):
        # As before

    def forward(self, x, hidden):
        if hidden is None:
            hidden = self._init_hidden(x)
        # Rest as before

    def _init_hidden(self, input_):
        # input_ has shape (1, batch, input_size); the state tensors need
        # hidden_size in the last dimension, so zeros_like would give the
        # wrong shape whenever input_size != hidden_size
        h = input_.new_zeros(1, input_.size(1), self.hidden_size)
        c = input_.new_zeros(1, input_.size(1), self.hidden_size)
        return h, c


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

    def forward(self, input_, hidden=None):
        # input_ is of dimensionality (1, time, batch, input_size), so each
        # time slice has the (1, batch, input_size) layout the cell expects
        outputs = []
        for x in th.unbind(input_, dim=1):
            # the cell returns (output, (h, c)); only the state tuple is
            # fed back into the next step
            out, hidden = self.lstm_cell(x, hidden)
            outputs.append(out.clone())
        return th.stack(outputs, dim=1)
I haven't tested the code since I'm working with a convLSTM implementation. Please let me know if something is wrong.
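For completeness, here is a rough usage sketch of my own (not from the original code, and equally untested) showing how the wrapper could be called, assuming the "# As before" parts are filled in from the cell above and the input is laid out as (1, time, batch, input_size):

time_steps, batch_size, input_size, hidden_size = 4, 3, 10, 20
rnn = LSTM(input_size, hidden_size)

# dummy sequence plus an explicit initial state
seq = th.randn(1, time_steps, batch_size, input_size)
h0 = th.zeros(1, batch_size, hidden_size)
c0 = th.zeros(1, batch_size, hidden_size)

out = rnn(seq, (h0, c0))
print(out.size())  # expected: (1, time_steps, batch_size, hidden_size)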
UPDATE: Fixed links.
I made a simple and general framework for customizing LSTMs: https://github.com/daehwannam/pytorch-rnn-util
You can implement custom LSTMs by designing LSTM cells and providing them to LSTMFrame.
An example of a custom LSTM in the package is LayerNormLSTM:
# snippet from rnn_util/seq.py
class LayerNormLSTM(LSTMFrame):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0,
                 r_dropout=0, bidirectional=False, layer_norm_enabled=True):
        r_dropout_layer = nn.Dropout(r_dropout)
        rnn_cells = tuple(
            tuple(
                LayerNormLSTMCell(
                    input_size if layer_idx == 0 else hidden_size * (2 if bidirectional else 1),
                    hidden_size,
                    dropout=r_dropout_layer,
                    layer_norm_enabled=layer_norm_enabled)
                for _ in range(2 if bidirectional else 1))
            for layer_idx in range(num_layers))

        super().__init__(rnn_cells, dropout, bidirectional)
LayerNormLSTM has the key options of PyTorch's standard LSTM and additional options, r_dropout and layer_norm_enabled:
# example.py
import torch
import rnn_util

bidirectional = True
num_directions = 2 if bidirectional else 1

rnn = rnn_util.LayerNormLSTM(10, 20, 2, dropout=0.3, r_dropout=0.25,
                             bidirectional=bidirectional, layer_norm_enabled=True)
# rnn = torch.nn.LSTM(10, 20, 2, bidirectional=bidirectional)

input = torch.randn(5, 3, 10)
h0 = torch.randn(2 * num_directions, 3, 20)
c0 = torch.randn(2 * num_directions, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

print(output.size())
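Assuming LayerNormLSTM follows the same (seq_len, batch, feature) conventions as the commented-out torch.nn.LSTM line, the printed size here should be torch.Size([5, 3, 40]), i.e. hidden_size * num_directions in the last dimension.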