Any example of torch 0.4.0 nn.LayerNorm for nn.LSTMCell?

Tags: lstm, pytorch

The PyTorch 0.4.0 release includes an nn.LayerNorm module.
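
For reference, a minimal standalone use of the module looks like this (the shapes below are just illustrative):

import torch
import torch.nn as nn

# Minimal illustration: normalize over the last (feature) dimension.
ln = nn.LayerNorm(10)      # normalized_shape matches the feature size
x = torch.randn(4, 10)     # batch of 4 samples, 10 features each
y = ln(x)                  # each sample now has ~zero mean, unit variance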

I want to add this layer to my LSTM network, but I cannot find any example of applying it to an LSTM network.

A PyTorch contributor implies that nn.LayerNorm can only be applied to an LSTM via nn.LSTMCell, since the fused nn.LSTM implementation does not expose the per-gate pre-activations that LayerNorm would normalize.

It would be a great help to get a Git repo or some code that implements nn.LayerNorm on nn.LSTMCell, or on any PyTorch LSTM network.

Thanks in advance

asked May 03 '18 by doyouknowkimchi


1 Answer

I am also looking for a solution. Here is an example from https://github.com/pytorch/pytorch/issues/11335 (thanks to @jinserk):

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormLSTMCell(nn.LSTMCell):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)

        # LayerNorm over the stacked gate pre-activations (4 gates per cell)
        # and over the cell state before the output activation.
        self.ln_ih = nn.LayerNorm(4 * hidden_size)
        self.ln_hh = nn.LayerNorm(4 * hidden_size)
        self.ln_ho = nn.LayerNorm(hidden_size)

    def forward(self, input, hidden=None):
        # Shape checks are inherited from nn.LSTMCell's base class.
        self.check_forward_input(input)
        if hidden is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
            cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden
        self.check_forward_hidden(input, hx, '[0]')
        self.check_forward_hidden(input, cx, '[1]')

        # Normalize the input-to-hidden and hidden-to-hidden projections
        # separately, then sum them to form the gate pre-activations.
        gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \
                + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh))
        # Treat the first three chunks as the sigmoid gates (input, forget,
        # output) and the last as the tanh candidate. Since the cell is
        # trained from scratch, relabeling PyTorch's (i, f, g, o) weight
        # layout this way is harmless.
        i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1)
        g = gates[:, (3 * self.hidden_size):].tanh()

        cy = (f * cx) + (i * g)
        hy = o * self.ln_ho(cy).tanh()
        return hy, cy
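
A minimal sketch of driving this cell by hand over a sequence (the sizes below are illustrative, not from the original issue):

# Illustrative usage: step the cell manually over a (T, N, input_size) batch.
cell = LayerNormLSTMCell(input_size=8, hidden_size=16)
seq = torch.randn(5, 3, 8)        # seq_len=5, batch=3, input_size=8
h = seq.new_zeros(3, 16)
c = seq.new_zeros(3, 16)
for x_t in seq:                   # iterate over time steps
    h, c = cell(x_t, (h, c))      # h, c: (batch, hidden_size)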


class LayerNormLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        num_directions = 2 if bidirectional else 1
        # Forward-direction cells; layers after the first consume the
        # concatenated outputs of both directions when bidirectional.
        self.hidden0 = nn.ModuleList([
            LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                              hidden_size=hidden_size, bias=bias)
            for layer in range(num_layers)
        ])

        if self.bidirectional:
            # Backward-direction cells.
            self.hidden1 = nn.ModuleList([
                LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions),
                                  hidden_size=hidden_size, bias=bias)
                for layer in range(num_layers)
            ])

    def forward(self, input, hidden=None):
        seq_len, batch_size, input_size = input.size()  # supports TxNxH only
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
            cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False)
        else:
            hx, cx = hidden

        # Per-time-step state buffers, built with list comprehensions (not
        # [...] * seq_len) so each time step gets its own inner list instead
        # of seq_len aliases of the same list.
        ht = [[None] * (self.num_layers * num_directions) for _ in range(seq_len)]
        ct = [[None] * (self.num_layers * num_directions) for _ in range(seq_len)]

        if self.bidirectional:
            xs = input
            for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)):
                l0, l1 = 2 * l, 2 * l + 1
                h0, c0, h1, c1 = hx[l0], cx[l0], hx[l1], cx[l1]
                for t, (x0, x1) in enumerate(zip(xs, reversed(xs))):
                    # Forward direction fills t = 0..T-1 ...
                    ht[t][l0], ct[t][l0] = layer0(x0, (h0, c0))
                    h0, c0 = ht[t][l0], ct[t][l0]
                    # ... while the backward direction fills t = T-1..0.
                    t = seq_len - 1 - t
                    ht[t][l1], ct[t][l1] = layer1(x1, (h1, c1))
                    h1, c1 = ht[t][l1], ct[t][l1]
                # The next layer consumes both directions concatenated.
                xs = [torch.cat((h[l0], h[l1]), dim=1) for h in ht]
            y = torch.stack(xs)
            hy = torch.stack(ht[-1])
            cy = torch.stack(ct[-1])
        else:
            h, c = hx, cx
            for t, x in enumerate(input):
                for l, layer in enumerate(self.hidden0):
                    ht[t][l], ct[t][l] = layer(x, (h[l], c[l]))
                    x = ht[t][l]
                h, c = ht[t], ct[t]
            # y collects the top layer's output at every time step.
            y = torch.stack([h[-1] for h in ht])
            hy = torch.stack(ht[-1])
            cy = torch.stack(ct[-1])

        return y, (hy, cy)
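
And a quick smoke test of the full module (sizes illustrative, not part of the original answer):

# Illustrative usage: bidirectional, 2-layer model on random data.
model = LayerNormLSTM(input_size=8, hidden_size=16, num_layers=2, bidirectional=True)
x = torch.randn(5, 3, 8)          # (seq_len, batch, input_size)
y, (hy, cy) = model(x)
print(y.shape)                    # torch.Size([5, 3, 32])  hidden_size * 2 directions
print(hy.shape)                   # torch.Size([4, 3, 16])  num_layers * num_directions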
answered Jan 03 '23 by radream