 

How to tie word embedding and softmax weights in keras?

It's commonplace for various neural network architectures in NLP and vision-language problems to tie the weights of the initial word embedding layer to those of the output softmax. Doing so usually boosts sentence generation quality (see, for example, Press and Wolf, "Using the Output Embedding to Improve Language Models").

In Keras it is typical to implement word embeddings using the Embedding class; however, there seems to be no easy way to tie the weights of this layer to the output softmax. Does anyone know how this could be implemented?

Rohit Gupta, asked Nov 03 '17

1 Answer

Be aware that Press and Wolf don't propose to freeze the weights to some pretrained ones, but to tie them. That means ensuring that the input and output weight matrices are always the same during training (i.e., kept synchronized).
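
As a minimal illustration of what "tied" means, here is a small NumPy-only sketch (the sizes are made up; this is not part of the Keras solution below):

import numpy as np

E = np.random.randn(5, 3)    # embedding matrix, V x H (V=5, H=3 here)
W_out = E.T                  # output projection: a transposed *view* of E, shape H x V
E[0, 0] = 42.0               # updating the embedding weights...
assert W_out[0, 0] == 42.0   # ...is immediately reflected in the output weights: tied, not copied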

In a typical NLP model (e.g. language modelling/translation), you have an input dimension (vocabulary) of size V and a hidden representation of size H. You start with an Embedding layer, which is a V x H matrix, and the output layer is (probably) something like Dense(V, activation='softmax'), which is an H2 x V matrix. When tying the weights, we want those two matrices to be the same (which requires H == H2). For doing this in Keras, I think the way to go is via shared layers:

In your model, you need to instantiate a shared embedding layer (of dimension V x H) and apply it to both your input and your output. But for the output you need to transpose it, to get the desired output dimension (H x V). So, we declare a TiedEmbeddingsTransposed layer, which transposes the embedding matrix of a given layer (and optionally applies an activation function):

from keras import activations
from keras import backend as K
from keras.layers import Layer  # in older Keras versions: from keras.engine.topology import Layer


class TiedEmbeddingsTransposed(Layer):
    """Layer for tying embeddings in an output layer.
    A regular embedding layer has the shape: V x H (V: size of the vocabulary. H: size of the projected space).
    In this layer, we'll go: H x V.
    With the same weights as the regular embedding.
    In addition, it may have an activation.
    # References
        - [Using the Output Embedding to Improve Language Models](https://arxiv.org/abs/1608.05859)
    """

    def __init__(self, tied_to=None,
                 activation=None,
                 **kwargs):
        super(TiedEmbeddingsTransposed, self).__init__(**kwargs)
        self.tied_to = tied_to
        self.activation = activations.get(activation)

    def build(self, input_shape):
        # No new weights are created here: we simply reuse (and transpose)
        # the weight matrix of the layer we are tied to.
        self.transposed_weights = K.transpose(self.tied_to.weights[0])
        self.built = True

    def compute_mask(self, inputs, mask=None):
        return mask

    def compute_output_shape(self, input_shape):
        # (batch, H) -> (batch, V)
        return input_shape[0], K.int_shape(self.tied_to.weights[0])[0]

    def call(self, inputs, mask=None):
        # Project the hidden states onto the (transposed) embedding matrix.
        output = K.dot(inputs, self.transposed_weights)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def get_config(self):
        # Note: tied_to is not serialized, which is what makes reloading tricky (see below).
        config = {'activation': activations.serialize(self.activation)}
        base_config = super(TiedEmbeddingsTransposed, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

The usage of this layer is:

# Declare the shared embedding layer
shared_embedding_layer = Embedding(V, H)
# Obtain word embeddings
word_embedding = shared_embedding_layer(input)
# Do stuff with your model
# Compute output (e.g. a vocabulary-size probability vector) with the shared layer:
output = TimeDistributed(TiedEmbeddingsTransposed(tied_to=shared_embedding_layer, activation='softmax'))(intermediate_rep)
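
For completeness, here is a minimal end-to-end sketch of how the pieces could fit together in a toy language model (the LSTM choice, the values of V and H, and the variable names are illustrative assumptions, not taken from NMT-Keras):

from keras.layers import Input, Embedding, LSTM, TimeDistributed
from keras.models import Model

V, H = 10000, 256                             # illustrative vocabulary / hidden sizes

inputs = Input(shape=(None,), dtype='int32')  # sequences of word indices
shared_embedding_layer = Embedding(V, H)
word_embedding = shared_embedding_layer(inputs)
intermediate_rep = LSTM(H, return_sequences=True)(word_embedding)
# The output projection reuses the (transposed) embedding weights:
output = TimeDistributed(TiedEmbeddingsTransposed(tied_to=shared_embedding_layer,
                                                  activation='softmax'))(intermediate_rep)

model = Model(inputs, output)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')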

I have tested this in NMT-Keras and it trains properly. However, when I try to load a trained model, I get an error related to the way Keras loads models: it doesn't restore the weights from the tied_to layer. I've found several questions regarding this (1, 2, 3), but I haven't managed to solve the issue. If someone has any ideas on the next steps to take, I'd be very glad to hear them :)
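
A workaround that might help in the meantime (I haven't verified it as a full solution): instead of keras.models.load_model, rebuild the exact same architecture in code, so that tied_to points at the freshly created shared Embedding, and then restore only the weights:

# build_tied_model() and the weights path are hypothetical placeholders
# for whatever function builds the model shown above.
model = build_tied_model()
model.load_weights('trained_model_weights.h5')  # the shared Embedding weights are restored,
                                                # and the output layer keeps pointing at them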

lvapeab, answered Nov 14 '22