
Implementing Luong Attention in PyTorch

I am trying to implement the attention described in Luong et al. 2015 in PyTorch myself, but I couldn't get it to work. Below is my code; I am only interested in the "general" attention case for now. I wonder if I am missing any obvious error. It runs, but doesn't seem to learn.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=self.output_size,
            embedding_dim=self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

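        # [0] removes the (num_layers * num_directions) dimension for now
        # self.attn(hidden)[0]: (1, hidden_size); encoder_outputs: (seq_len, hidden_size)
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())  # (1, seq_len)
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)  # (1, hidden_size)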

        # hc: [hidden, context]
        out_hc = torch.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6

        return output, hidden, attn_weights
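For reference, these are the equations the comments refer to, from Luong et al. (2015) with the "general" score, where h_t is the current decoder hidden state and \bar{h}_s is the encoder output at source position s. In the code, self.attn plays the role of W_a (although nn.Linear also adds a bias term that the paper's score does not have), self.Whc is W_c (applied to [h_t; c_t] rather than [c_t; h_t], which only permutes the columns of a learned matrix), and self.Ws is W_s, with log_softmax used in place of softmax, presumably for use with nn.NLLLoss.

```latex
\begin{aligned}
\operatorname{score}(h_t, \bar{h}_s) &= h_t^\top W_a \bar{h}_s && \text{(eq. 8, general)} \\
a_t(s) &= \frac{\exp\big(\operatorname{score}(h_t, \bar{h}_s)\big)}{\sum_{s'} \exp\big(\operatorname{score}(h_t, \bar{h}_{s'})\big)} && \text{(eq. 7)} \\
c_t &= \sum_s a_t(s)\, \bar{h}_s \\
\tilde{h}_t &= \tanh\big(W_c [c_t; h_t]\big) && \text{(eq. 5)} \\
p(y_t \mid y_{<t}, x) &= \operatorname{softmax}\big(W_s \tilde{h}_t\big) && \text{(eq. 6)}
\end{aligned}
```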

I have studied the attention implemented in

https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

and

https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb

  • The first one isn't the exact attention mechanism I am looking for. A major disadvantage is that its attention layer depends on the maximum sequence length (self.attn = nn.Linear(self.hidden_size * 2, self.max_length)), which could be expensive for long sequences.
  • The second one is closer to what's described in the paper, but still not the same, as there is no tanh. Besides, it is really slow after updating it to the latest version of PyTorch (ref). Also, I don't know why it takes the last context (ref).
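The max_length dependence in the first tutorial is not inherent to attention: with the Luong "general" score, W_a is hidden_size × hidden_size and the softmax simply runs over however many encoder positions there are. Below is a minimal batched sketch under stated assumptions: tensors are batch-first, (batch, src_len, hidden); the class name LuongGeneralAttention and the optional padding mask are my own illustration, not code from either tutorial; and the linear layers use bias=False so the score is exactly h_t^T W_a h_s.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LuongGeneralAttention(nn.Module):
    """Luong "general" attention, score(h_t, h_s) = h_t^T W_a h_s (sketch)."""

    def __init__(self, hidden_size):
        super().__init__()
        # No bias, so the score is exactly h_t^T W_a h_s.
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        # W_c maps the concatenation [c_t; h_t] back to hidden_size (eq. 5).
        self.W_c = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, h_t, encoder_outputs, mask=None):
        # h_t: (batch, hidden); encoder_outputs: (batch, src_len, hidden)
        # mask (optional, assumed convention): bool (batch, src_len), True = real token.
        keys = self.W_a(encoder_outputs)                       # (batch, src_len, hidden)
        scores = torch.bmm(keys, h_t.unsqueeze(2)).squeeze(2)  # (batch, src_len)
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))  # padding gets zero weight
        weights = F.softmax(scores, dim=1)                     # eq. 7
        # c_t = sum_s a_t(s) h_s
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, hidden)
        h_tilde = torch.tanh(self.W_c(torch.cat([context, h_t], dim=1)))      # eq. 5
        return h_tilde, weights


# Usage: the same parameters handle any source length; no max_length is needed.
torch.manual_seed(0)
attn = LuongGeneralAttention(hidden_size=8)
h_t = torch.randn(2, 8)
for src_len in (5, 40):
    h_tilde, weights = attn(h_t, torch.randn(2, src_len, 8))
    print(h_tilde.shape, weights.shape)
```

Computing the whole batch with torch.bmm avoids looping over examples in Python; the module's parameter count depends only on hidden_size, never on the sequence length.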
asked May 28 '18 at 18:05 by zyxue



1 Answer

This version works, and it closely follows the definition of Luong attention (general). The main difference from the code in the question is the separation of embedding_size and hidden_size, which, after some experimentation, appears to be important for training. Previously, I made both of them the same size (256), which caused trouble for learning: the network seemed able to learn only half of the sequence.

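import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed definition: the original snippet uses DEVICE without defining it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EncoderRNN(nn.Module):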
    def __init__(self, input_size, embedding_size, hidden_size,
                 num_layers=1, bidirectional=False, batch_size=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.batch_size = batch_size

        self.embedding = nn.Embedding(input_size, embedding_size)

        self.gru = nn.GRU(embedding_size, hidden_size, num_layers,
                          bidirectional=bidirectional)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        directions = 2 if self.bidirectional else 1
        return torch.zeros(
            self.num_layers * directions,
            self.batch_size,
            self.hidden_size,
            device=DEVICE
        )


class AttnDecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=output_size,
            embedding_dim=embedding_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(embedding_size, hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(hidden_size * 2, hidden_size)
        # s: softmax
        self.Ws = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden, context]
        hc = torch.cat([hidden[0], context], dim=1)
        out_hc = torch.tanh(self.Whc(hc))
        output = F.log_softmax(self.Ws(out_hc), dim=1)

        return output, hidden, attn_weights
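AttnDecoderRNN expects encoder_outputs as a 2-D tensor of shape (seq_len, hidden_size), while EncoderRNN.forward processes one token per call. A minimal sketch of how the per-step outputs can be stacked into that shape, using plain nn.Embedding and nn.GRU with made-up sizes and token ids (not the classes above), and checking that it matches running the whole sequence through the GRU in a single call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embedding_size, hidden_size = 20, 16, 32  # illustrative sizes only

embedding = nn.Embedding(vocab_size, embedding_size)
gru = nn.GRU(embedding_size, hidden_size)  # expects (seq_len, batch, features)

tokens = torch.tensor([3, 7, 1, 12, 5])  # hypothetical source token ids

# 1) Step the GRU one token at a time, as EncoderRNN.forward does,
#    and keep each step's output row.
hidden = torch.zeros(1, 1, hidden_size)
rows = []
for tok in tokens:
    out, hidden = gru(embedding(tok).view(1, 1, -1), hidden)
    rows.append(out[0])  # out: (1, 1, hidden_size) -> (1, hidden_size)
# Stack into (seq_len, hidden_size), the shape AttnDecoderRNN uses in torch.mm.
encoder_outputs = torch.cat(rows, dim=0)

# 2) Run the whole sequence through the GRU in one call (h0 defaults to zeros).
full_out, full_hidden = gru(embedding(tokens).view(len(tokens), 1, -1))

print(encoder_outputs.shape)  # torch.Size([5, 32])
print(torch.allclose(encoder_outputs, full_out[:, 0], atol=1e-5))
```

Note that with bidirectional=True the encoder outputs have 2 * hidden_size features, so they would not fit the decoder's nn.Linear(hidden_size, hidden_size) attention layer without an extra projection.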
answered Sep 21 '22 at 14:09 by zyxue