trying here to make an eager exec model work with LR decay, but no success. It seems to be a bug, since it appear that the learning rate decay tensor does not get updated. If I am missing something can you land a hand here. Thanks.
The code bellow is learning some word embeddings. However, the learning rate decay section does not work at all.
class Word2Vec(tf.keras.Model):
def __init__(self, vocab_size, embed_size, num_sampled=NUM_SAMPLED):
self.vocab_size = vocab_size
self.num_sampled = num_sampled
self.embed_matrix = tfe.Variable(tf.random_uniform(
[vocab_size, embed_size]), name="embedding_matrix")
self.nce_weight = tfe.Variable(tf.truncated_normal(
[vocab_size, embed_size],
stddev=1.0 / (embed_size ** 0.5)), name="weights")
self.nce_bias = tfe.Variable(tf.zeros([vocab_size]), name="biases")
def compute_loss(self, center_words, target_words):
"""Computes the forward pass of word2vec with the NCE loss."""
embed = tf.nn.embedding_lookup(self.embed_matrix, center_words)
loss = tf.reduce_mean(tf.nn.nce_loss(weights=self.nce_weight,
biases=self.nce_bias,
labels=target_words,
inputs=embed,
num_sampled=self.num_sampled,
num_classes=self.vocab_size))
return loss
def gen():
yield from word2vec_utils.batch_gen(DOWNLOAD_URL, EXPECTED_BYTES,
VOCAB_SIZE, BATCH_SIZE, SKIP_WINDOW,
VISUAL_FLD)
def main():
dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32),
(tf.TensorShape([BATCH_SIZE]),
tf.TensorShape([BATCH_SIZE, 1])))
global_step = tf.train.get_or_create_global_step()
starter_learning_rate = 1.0
end_learning_rate = 0.01
decay_steps = 1000
learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step.numpy(),
decay_steps, end_learning_rate,
power=0.5)
train_writer = tf.contrib.summary.create_file_writer('./checkpoints')
train_writer.set_as_default()
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.95)
model = Word2Vec(vocab_size=VOCAB_SIZE, embed_size=EMBED_SIZE)
grad_fn = tfe.implicit_value_and_gradients(model.compute_loss)
total_loss = 0.0 # for average loss in the last SKIP_STEP steps
checkpoint_dir = "./checkpoints/"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tfe.Checkpoint(optimizer=optimizer,
model=model,
optimizer_step=tf.train.get_or_create_global_step())
while global_step < NUM_TRAIN_STEPS:
for center_words, target_words in tfe.Iterator(dataset):
with tf.contrib.summary.record_summaries_every_n_global_steps(100):
if global_step >= NUM_TRAIN_STEPS:
break
loss_batch, grads = grad_fn(center_words, target_words)
tf.contrib.summary.scalar('loss', loss_batch)
tf.contrib.summary.scalar('learning_rate', learning_rate)
# print(grads)
# print(len(grads))
total_loss += loss_batch
optimizer.apply_gradients(grads, global_step)
if (global_step.numpy() + 1) % SKIP_STEP == 0:
print('Average loss at step {}: {:5.1f}'.format(
global_step.numpy(), total_loss / SKIP_STEP))
total_loss = 0.0
root.save(file_prefix=checkpoint_prefix)
if __name__ == '__main__':
main()
Note that when eager execution is enabled, the tf.Tensor objects represent concrete values (as opposed to symbolic handles of computation that will occur on Session.run() calls).
As a result, in your code snippet above, the line:
learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step.numpy(),
decay_steps, end_learning_rate,
power=0.5)
is computing the decayed value once, using the global_step at the time it was invoked, and when the optimizer is being created with:
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.95)
it is being given a fixed learning rate.
To decay the learning rate, you'd want to invoke tf.train.polynomial_decay repeatedly (with updated values for global_step). One way to do this would be to replicate what is done in the RNN example, using something like this:
starter_learning_rate = 1.0
learning_rate = tfe.Variable(starter_learning_rate)
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.95)
while global_step < NUM_TRAIN_STEPS:
# ....
learning_rate.assign(tf.train.polynomial_decay(starter_learning_rate, global_step, decay_steps, end_learning_rate, power=0.5))
This way you've captured the learning_rate in a variable that can be updated. Furthermore, it's simple to include the current learning_rate in the checkpoint as well (by including it when creating the Checkpoint object).
Hope that helps.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With