
LSTM in PyTorch

I'm new to PyTorch. I came across this GitHub repository (link to full code example) containing various examples.

There is also an example about LSTMs; this is the network class:

import torch
import torch.nn as nn
from torch.autograd import Variable  # pre-0.4 PyTorch API used by this example

# RNN Model (Many-to-One)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial states 
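        # h0/c0 shape: (num_layers, batch_size, hidden_size)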
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) 
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))

        # Forward propagate RNN
        out, _ = self.lstm(x, (h0, c0))  

        # Decode hidden state of last time step
        out = self.fc(out[:, -1, :])  
        return out

So my question is about the following lines:

h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size)) 
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))

As far as I understand it, forward() is called for every training example. But this would mean that the hidden state and cell state are reset, i.e. replaced with a matrix of zeros, for every training example.

The names h0 and c0 indicate that this is only the hidden/cell state at t=0, but why then are these zero matrices handed over to the LSTM with every training example?

Even if they were just ignored after the first call, that would not be a very nice solution.

When I test the code, it reports an accuracy of 97% on the MNIST set, so it seems to work this way, but it doesn't make sense to me.

Hope someone can help me out with this.

Thanks in advance!

asked Feb 16 '18 by MBT

People also ask

How does LSTM work in Pytorch?

An LSTM cell takes the following inputs: input, (h_0, c_0). input: a tensor of shape (batch, input_size), where input_size was declared when creating the LSTM cell. h_0: a tensor containing the initial hidden state for each element in the batch, of shape (batch, hidden_size). c_0: a tensor containing the initial cell state for each element in the batch, also of shape (batch, hidden_size).
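A minimal sketch of those shapes (the sizes here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

# One LSTM cell with arbitrary example sizes
cell = nn.LSTMCell(input_size=10, hidden_size=20)

x = torch.randn(3, 10)       # one time step for a batch of 3
h0 = torch.zeros(3, 20)      # initial hidden state: (batch, hidden_size)
c0 = torch.zeros(3, 20)      # initial cell state: (batch, hidden_size)

h1, c1 = cell(x, (h0, c0))   # one step forward
print(h1.shape, c1.shape)    # torch.Size([3, 20]) torch.Size([3, 20])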

What is the output of LSTM layer Pytorch?

The output of the PyTorch LSTM layer is a tuple with two elements: output, which holds the last layer's hidden state for every time step, and (h_n, c_n), the final hidden and cell states.
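For example, a small sketch with made-up sizes showing both elements and their shapes:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 10)      # (batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)   # states default to zeros when omitted
print(output.shape)            # torch.Size([3, 5, 20]) - last layer, every time step
print(h_n.shape, c_n.shape)    # torch.Size([2, 3, 20]) - final state per layer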

What is LSTM and how it works?

An LSTM is a type of recurrent neural network (RNN) capable of learning long-term dependencies, especially in sequence prediction problems. It has feedback connections, i.e., it can process entire sequences of data, not just single data points such as images.

What is LSTM in neural network?

Long short-term memory (LSTM) is an artificial neural network architecture used in the fields of artificial intelligence and deep learning. Unlike standard feedforward neural networks, an LSTM has feedback connections.


1 Answer

Obviously I was on the wrong track with this. I was confusing the hidden units with the hidden/cell state. Only the weights of the LSTM are learned during training; the hidden state and cell state are reset at the beginning of every sequence. So it makes sense that it is programmed this way.
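To see the same thing in code (a minimal sketch, not from the linked repository): in current PyTorch, Variable is no longer needed, and nn.LSTM falls back to zero states when none are passed, so the explicit h0/c0 in the example behave exactly like the default.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 10)    # (batch, seq_len, input_size)

# Passing explicit zero states ...
h0 = torch.zeros(2, 3, 20)   # (num_layers, batch_size, hidden_size)
c0 = torch.zeros(2, 3, 20)
out_explicit, _ = lstm(x, (h0, c0))

# ... is equivalent to passing no states at all: they default to zeros.
out_default, _ = lstm(x)
print(torch.equal(out_explicit, out_default))  # True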

Sorry for this.

answered Sep 29 '22 by MBT