How can I use LSTM in pytorch for classification?

Tags:

pytorch

My code is as below:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class Mymodel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, batch_size):
        super(Mymodel, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.batch_size = batch_size

        self.lstm = nn.LSTM(input_size, hidden_size)
        self.proj = nn.Linear(hidden_size, output_size)
        self.hidden = self.init_hidden()


    def init_hidden(self):
        return (Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)),
                Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)))

    def forward(self, x):
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        output = self.proj(lstm_out)
        result = F.sigmoid(output)
        return result

I want to use an LSTM to classify a sentence as good (1) or bad (0). With this code, I get a result of shape time_step * batch_size * 1, not 0 or 1. How do I edit the code to get the classification result?

asked by Zhao Wulanaren

1 Answer

Theory:

Recall that an LSTM outputs a vector for every input in the sequence. You are feeding it sentences, which are sequences of words (probably converted to indices and then embedded as vectors). This code from the PyTorch LSTM tutorial makes the point clear (the *** comment marks my emphasis):

import torch
import torch.nn as nn
from torch import autograd

lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [autograd.Variable(torch.randn((1, 3)))
          for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (autograd.Variable(torch.randn(1, 1, 3)),
          autograd.Variable(torch.randn((1, 1, 3))))
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

# alternatively, we can do the entire sequence all at once.
# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# *** (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument  to the lstm at a later time
# Add the extra 2nd dimension
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (autograd.Variable(torch.randn(1, 1, 3)),
          autograd.Variable(torch.randn(1, 1, 3)))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)

One more time: compare the last slice of "out" with "hidden" in the output above; they are the same. Why? Well...

If you're familiar with LSTMs, I'd recommend the PyTorch LSTM docs at this point. Under the Outputs section, notice that h_t is output at every timestep t.
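
If it helps to see this concretely, here is a minimal sketch (written with plain tensors, which recent PyTorch versions accept in place of Variable) that checks the claim directly; the dimensions are illustrative:

import torch
import torch.nn as nn

lstm = nn.LSTM(3, 3)               # input dim 3, hidden dim 3
inputs = torch.randn(5, 1, 3)      # (seq_len=5, batch=1, input_dim=3)
hidden = (torch.zeros(1, 1, 3),    # (num_layers, batch, hidden_dim)
          torch.zeros(1, 1, 3))

out, (h_n, c_n) = lstm(inputs, hidden)
print(out.shape)                        # torch.Size([5, 1, 3]): one output vector per timestep
print(torch.allclose(out[-1], h_n[0]))  # True: the last slice of out equals h_n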

Now if you aren't used to LSTM-style equations, take a look at Chris Olah's LSTM blog post. Scroll down to the diagram of the unrolled network:

[Diagram of the unrolled LSTM network. Credit: C. Olah, "Understanding LSTM Networks"]

As you feed your sentence in word by word (x_i, then x_{i+1}, and so on), you get an output at each timestep. But you want to interpret the entire sentence to classify it, so you must wait until the LSTM has seen all the words. That is, you need to take h_t, where t is the number of words in your sentence.

Code:

Here's a coding reference. I'm not going to copy-paste the entire thing, just the relevant parts. The magic happens at self.hidden2label(lstm_out[-1]):

class LSTMClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, label_size, batch_size):
        ...  # elided: the super().__init__() call and attribute assignments (hidden_dim, batch_size, etc.)
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2label = nn.Linear(hidden_dim, label_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (autograd.Variable(torch.zeros(1, self.batch_size, self.hidden_dim)),
                autograd.Variable(torch.zeros(1, self.batch_size, self.hidden_dim)))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        x = embeds.view(len(sentence), self.batch_size, -1)
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        # lstm_out is (seq_len, batch_size, hidden_dim); keep only the last timestep
        y = self.hidden2label(lstm_out[-1])
        log_probs = F.log_softmax(y)
        return log_probs
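
To tie this back to the question: with label_size=2, the model emits one pair of log-probabilities per sentence, and argmax gives the 0/1 prediction. A rough usage sketch, assuming the elided __init__ lines store hidden_dim and batch_size as attributes (the dimensions and indices below are illustrative, not from the original post):

import torch

model = LSTMClassifier(embedding_dim=100, hidden_dim=50,
                       vocab_size=10000, label_size=2, batch_size=1)

sentence = torch.tensor([4, 17, 250, 3])  # word indices for one sentence
model.hidden = model.init_hidden()        # reset the hidden state between sentences
log_probs = model(sentence)               # shape: (batch_size, label_size)
prediction = log_probs.argmax(dim=1)      # 0 = bad, 1 = good
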
answered by Dylan F