Is it possible to freeze only certain embedding weights in the embedding layer in pytorch?

When using GloVe embedding in NLP tasks, some words from the dataset might not exist in GloVe. Therefore, we instantiate random weights for these unknown words.

Would it be possible to freeze weights gotten from GloVe, and train only the newly instantiated weights?

I am only aware that we can set: model.embedding.weight.requires_grad = False

But this makes the new words untrainable..

Or are there better ways to extract semantics of words..

asked Feb 28 '19 by rcshon

People also ask

How do you freeze specific weights in PyTorch?

In PyTorch we can freeze a layer by setting requires_grad to False on its parameters. Freezing weights is helpful when we want to use a pretrained model.
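
For instance, a minimal sketch (the two-layer model here is just an example, not from the original post):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.Linear(20, 2),
)

# Freeze only the first layer; its weights will not receive gradient updates
for param in model[0].parameters():
    param.requires_grad = False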

What does embedding layer do in PyTorch?

The embedding layer works like a lookup table: each word is mapped to an integer index, and that index is used to look up a row of the table. The keys are the word indices and the values are the word vectors.
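
A small sketch of that lookup behaviour (the vocabulary size and dimension are arbitrary here):

import torch

embedding = torch.nn.Embedding(num_embeddings=5, embedding_dim=3)  # 5 words, 3-dim vectors
indices = torch.LongTensor([0, 3, 3])  # integer word indices act as keys
vectors = embedding(indices)           # each index selects a row of the weight table
print(vectors.shape)                   # torch.Size([3, 3])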

How do you freeze layers in Pretrained model PyTorch?

Since we have access to all the modules and layers and their parameters, we can freeze them by setting the parameters' requires_grad flag to False. This prevents the gradients for these parameters from being calculated in the backward pass, which in turn prevents the optimizer from updating them.
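
For example, a sketch with a torchvision model (any pretrained model works the same way; the 10-class head is an assumption for illustration):

import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)

# Freeze every parameter; no gradients are computed for them in the backward pass
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer and train only its parameters
model.fc = torch.nn.Linear(model.fc.in_features, 10)
optimizer = torch.optim.SGD(model.fc.parameters(), lr=1e-3)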

Is torch nn embedding trainable?

nn.Embedding is essentially a lookup table, and its weights are trainable by default. You first map each word in the vocabulary to a unique integer index, and nn.Embedding then maps that index to a vector of, say, size 300. If those 300-dimensional vectors are never trained (and not initialized from pretrained embeddings), they do not represent relations among words.
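
In code, trainability is controlled by requires_grad, or by the freeze argument of from_pretrained (a sketch with made-up sizes):

import torch

# A plain embedding: its weight is a trainable parameter by default
embedding = torch.nn.Embedding(1000, 300)
print(embedding.weight.requires_grad)  # True

# Loading pretrained vectors; freeze=True (the default) makes them non-trainable
pretrained = torch.rand(1000, 300)  # stand-in for real GloVe vectors
frozen = torch.nn.Embedding.from_pretrained(pretrained, freeze=True)
print(frozen.weight.requires_grad)  # False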


1 Answer

1. Divide embeddings into two separate objects

One approach would be to use two separate embeddings: one for the pretrained vectors and one for the vectors to be trained.

The GloVe embedding should be frozen, while tokens for which there is no pretrained representation are looked up in the trainable layer.

This works if you arrange your data so that tokens with a pretrained representation use a smaller index range than tokens without a GloVe representation. Let's say your pretrained indices are in the range [0, 300], while those without a representation are in [301, 500]. I would go with something along those lines:

import numpy as np
import torch


class YourNetwork(torch.nn.Module):
    def __init__(self, glove_embeddings: np.ndarray, how_many_tokens_not_present: int):
        super().__init__()
        # Frozen embedding holding the pretrained GloVe vectors
        self.pretrained_embedding = torch.nn.Embedding.from_pretrained(
            torch.from_numpy(glove_embeddings).float()
        )
        # Trainable embedding for tokens without a GloVe representation
        self.trainable_embedding = torch.nn.Embedding(
            how_many_tokens_not_present, glove_embeddings.shape[1]
        )
        # Rest of your network setup

    def forward(self, batch):
        # Tokens in the batch which do not have a pretrained representation should have
        # indices GREATER than or equal to the pretrained ones; adjust your data
        # creating function accordingly
        mask = batch >= self.pretrained_embedding.num_embeddings

        # You may want to optimize this; you could probably get away without the clone,
        # though I'm not currently sure how
        pretrained_batch = batch.clone()
        pretrained_batch[mask] = 0

        embedded_batch = self.pretrained_embedding(pretrained_batch)

        # Every token without a representation has to be brought into the
        # [0, how_many_tokens_not_present) range of the trainable embedding
        trainable_batch = batch - self.pretrained_embedding.num_embeddings
        # Zero out the ones which already have a pretrained embedding
        trainable_batch[~mask] = 0
        non_pretrained_embedded_batch = self.trainable_embedding(trainable_batch)

        # And finally replace the placeholder rows produced by the pretrained
        # embedding with the trainable embeddings for the unknown tokens.
        embedded_batch[mask] = non_pretrained_embedded_batch[mask]

        # Rest of your code
        ...

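A rough usage sketch of the above (the vocabulary sizes and the random matrix standing in for GloVe are made up for illustration):

import numpy as np
import torch

# Hypothetical sizes: 301 pretrained tokens (indices [0, 300]) of dimension 50,
# and 200 new tokens (indices [301, 500]) without GloVe vectors
glove = np.random.randn(301, 50).astype(np.float32)  # stand-in for real GloVe vectors
net = YourNetwork(glove, how_many_tokens_not_present=200)

# A batch mixing pretrained and new token indices
batch = torch.LongTensor([[5, 42, 310], [301, 7, 499]])
net(batch)  # the forward above is a skeleton, so it does not return anything yet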

2. Zero gradients for specified tokens.

This one is a bit tricky, but I think it's pretty concise and easy to implement. If you obtain the indices of tokens which do have a GloVe representation, you can explicitly zero their gradient after backpropagation, so those rows will not be updated, while the rows for the new tokens keep their gradients and keep training.

import torch

embedding = torch.nn.Embedding(10, 3)
X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

values = embedding(X)
loss = values.mean()

# Use whatever loss you want
loss.backward()

# Let's say those indices in your embedding are pretrained (have GloVe representation)
indices = torch.LongTensor([2, 4, 5])

print("Before zeroing out gradient")
print(embedding.weight.grad)

print("After zeroing out gradient")
embedding.weight.grad[indices] = 0
print(embedding.weight.grad)

And the output of the second approach:

Before zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
After zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
answered Sep 21 '22 by Szymon Maszke