
Update only part of the word embedding matrix in Tensorflow

Assuming that I want to update a pre-trained word-embedding matrix during training, is there a way to update only a subset of the word embedding matrix?

I have looked at the TensorFlow API docs and found this:

# Create an optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)

# grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)

However, how do I apply that to the word-embedding matrix? Suppose I do:

word_emb = tf.Variable(0.2 * tf.random_uniform([syn0.shape[0], s['es']],
                                               minval=-1.0, maxval=1.0,
                                               dtype=tf.float32),
                       name='word_emb', trainable=False)

gather_emb = tf.gather(word_emb, indices)  # assuming that I pass some indices as a placeholder through feed_dict

opt = tf.train.AdamOptimizer(1e-4)
grad = opt.compute_gradients(loss, gather_emb)

How do I then use opt.apply_gradients and tf.scatter_update to update the original embedding matrix? (Also, TensorFlow throws an error if the second argument of compute_gradients is not a tf.Variable.)

asked Mar 04 '16 by user1718064


3 Answers

TL;DR: With the default implementation of opt.minimize(loss), TensorFlow will generate a sparse update for word_emb that modifies only the rows of word_emb that participated in the forward pass.

The gradient of the tf.gather(word_emb, indices) op with respect to word_emb is a tf.IndexedSlices object (see the implementation for more details). This object represents a sparse tensor that is zero everywhere, except for the rows selected by indices. A call to opt.minimize(loss) calls AdamOptimizer._apply_sparse(word_emb_grad, word_emb), which makes a call to tf.scatter_sub(word_emb, ...)* that updates only the rows of word_emb that were selected by indices.
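
For concreteness, here is a minimal sketch of that default path; the loss below is a made-up squared-norm loss, purely for illustration:

word_emb = tf.Variable(tf.random_uniform([10000, 128], -1.0, 1.0), name='word_emb')
indices = tf.placeholder(tf.int32, shape=[None])
gather_emb = tf.gather(word_emb, indices)      # gradient w.r.t. word_emb is a tf.IndexedSlices
loss = tf.reduce_mean(tf.square(gather_emb))   # hypothetical loss, for illustration only
opt = tf.train.AdamOptimizer(1e-4)
train_op = opt.minimize(loss)                  # sparse update: only the rows in `indices` change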

If, on the other hand, you want to modify the tf.IndexedSlices that is returned by opt.compute_gradients(loss, word_emb), you can perform arbitrary TensorFlow operations on its indices and values properties, and create a new tf.IndexedSlices that is passed to opt.apply_gradients() as the gradient for word_emb. For example, you could cap the gradients with MyCapper() (as in the example above) using the following calls:

(grad, var), = opt.compute_gradients(loss, [word_emb])
capped_grad = tf.IndexedSlices(MyCapper(grad.values), grad.indices)
train_op = opt.apply_gradients([(capped_grad, var)])

Similarly, you could change the set of rows that will be modified by creating a new tf.IndexedSlices with different indices, as in the sketch below.
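
For example, a hedged sketch that keeps only the gradient slices whose row index appears in a hypothetical allowed_indices tensor, so every other row is left untouched:

(grad, var), = opt.compute_gradients(loss, [word_emb])
# allowed_indices is a hypothetical 1-D tensor of row ids that are allowed to change.
allowed = tf.cast(allowed_indices, grad.indices.dtype)
keep = tf.reduce_any(tf.equal(tf.expand_dims(grad.indices, 1),
                              tf.expand_dims(allowed, 0)), axis=1)
masked_grad = tf.IndexedSlices(tf.boolean_mask(grad.values, keep),
                               tf.boolean_mask(grad.indices, keep))
train_op = opt.apply_gradients([(masked_grad, var)])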


* In general, if you want to update only part of a variable in TensorFlow, you can use the tf.scatter_update(), tf.scatter_add(), or tf.scatter_sub() operators, which respectively set, add to (+=) or subtract from (-=) the value previously stored in a variable.
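
For instance, a small sketch (the shapes and values are made up) that overwrites two rows of a variable in place:

emb = tf.Variable(tf.zeros([5, 3]))
new_rows = tf.constant([[1.0, 1.0, 1.0],
                        [2.0, 2.0, 2.0]])
assign_op = tf.scatter_update(emb, [0, 3], new_rows)   # rows 0 and 3 are overwritten

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(assign_op)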

answered Nov 16 '22 by mrry


Since you just want to select which entries get updated (and not change the gradients themselves), you can do the following.

Let indices_to_update be a boolean tensor that indicates the entries you wish to update, and let entry_stop_gradients be the helper defined in the linked answer (a sketch of the usual implementation is shown below). Then:

gather_emb = entry_stop_gradients(gather_emb, indices_to_update)

(Source)
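
The helper itself lives behind the link above, but a minimal sketch of the usual pattern looks like this (assuming the boolean mask has been cast to a float tensor with 1.0 where gradients should flow and 0.0 where entries should stay frozen):

def entry_stop_gradients(target, mask):
    # Frozen entries flow through tf.stop_gradient (same forward value, no gradient);
    # the remaining entries are passed through unchanged and stay trainable.
    mask_stop = 1.0 - mask
    return tf.stop_gradient(mask_stop * target) + mask * target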

answered Nov 16 '22 by Oren


Actually, I was also struggling with this problem. In my case, I needed to train a model with w2v embeddings, but not all of the tokens existed in the embedding matrix, so for the missing tokens I used random initialization. Of course, the tokens whose embeddings were already trained shouldn't be updated, so I came up with the following solution:

class PartialEmbeddingsUpdate(tf.keras.layers.Layer):
    def __init__(self, len_vocab, weights, indices_to_update):
        super(PartialEmbeddingsUpdate, self).__init__()
        self.embeddings = tf.Variable(weights, name='embedding', dtype=tf.float32)
        # Boolean vector of length len_vocab: True for the rows that should be updated.
        self.bool_mask = tf.equal(tf.expand_dims(tf.range(0, len_vocab), 1),
                                  tf.expand_dims(indices_to_update, 0))
        self.bool_mask = tf.reduce_any(self.bool_mask, 1)
        self.bool_mask_not = tf.logical_not(self.bool_mask)
        # Cast both masks to float column vectors so they broadcast over the embedding rows.
        self.bool_mask_not = tf.expand_dims(tf.cast(self.bool_mask_not, dtype=self.embeddings.dtype), 1)
        self.bool_mask = tf.expand_dims(tf.cast(self.bool_mask, dtype=self.embeddings.dtype), 1)

    def call(self, input):
        input = tf.cast(input, dtype=tf.int32)
        # Frozen rows pass through tf.stop_gradient; rows in indices_to_update keep their gradient.
        embeddings = tf.stop_gradient(self.bool_mask_not * self.embeddings) + self.bool_mask * self.embeddings
        return tf.gather(embeddings, input)

Here len_vocab is your vocabulary length, weights is the weight matrix (some rows of which shouldn't be updated), and indices_to_update holds the indices of the tokens that should be updated. After that I applied this layer instead of tf.keras.layers.Embedding. I hope it helps everyone who encountered the same problem.
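
A hypothetical usage sketch (the vocabulary size, embedding dimension, and set of trainable ids below are made up):

import numpy as np
import tensorflow as tf

pretrained = np.random.uniform(-1.0, 1.0, size=(10000, 300)).astype('float32')  # stand-in for w2v weights
trainable_ids = tf.constant([5, 17, 42])   # rows that were randomly initialized and should keep training

emb_layer = PartialEmbeddingsUpdate(10000, pretrained, trainable_ids)
token_ids = tf.keras.Input(shape=(None,), dtype=tf.int32)
embedded = emb_layer(token_ids)            # drop-in replacement for tf.keras.layers.Embedding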

answered Nov 15 '22 by Kovenko Volodymyr