
Implementing gradient descent in TensorFlow instead of using the one provided with it

I want to use gradient descent with momentum (keep track of previous gradients) while building a classifier in TensorFlow.

So I don't want to use `tensorflow.train.GradientDescentOptimizer`; instead, I want to use `tensorflow.gradients` to calculate the gradients, keep track of the previous gradients, and update the weights based on all of them.

How do I do this in TensorFlow?

prepmath asked Aug 26 '16 13:08


1 Answer

TensorFlow already has an implementation of gradient descent with momentum: `tf.train.MomentumOptimizer`.
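
For reference, the TF 1.x documentation describes the update that `tf.train.MomentumOptimizer` (with `use_nesterov=False`) applies to each variable as `accumulation = momentum * accumulation + gradient` followed by `variable -= learning_rate * accumulation`. Writing $a_t$ for the accumulator, $\mu$ for the momentum, $\eta$ for the learning rate and $g_t$ for the gradient at step $t$:

```latex
\begin{aligned}
a_t      &= \mu\, a_{t-1} + g_t, \qquad a_0 = 0, \\
\theta_t &= \theta_{t-1} - \eta\, a_t .
\end{aligned}
```

Unrolling the first line gives $a_t = \sum_{k=1}^{t} \mu^{t-k} g_k$, an exponentially decaying sum of all previous gradients, which is the "keep track of previous gradients" part of the question.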

To answer your general question about implementing your own optimization algorithm: TensorFlow gives you the primitives to calculate gradients and to update variables using those gradients. Suppose `loss` is your model's loss function and `var_list` is a Python list of the TensorFlow variables in your model (which you can get by calling `tf.all_variables` or `tf.trainable_variables`). You can then calculate the gradients with respect to those variables as follows:

```python
import tensorflow as tf  # TF 1.x graph-mode API
grads = tf.gradients(loss, var_list)  # one gradient tensor per variable
```

For plain gradient descent, you simply subtract the product of the learning rate and the gradient from each variable. The code for that looks as follows:

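```python
# Plain gradient descent: var <- var - learning_rate * grad
var_updates = []
for grad, var in zip(grads, var_list):
    var_updates.append(var.assign_sub(learning_rate * grad))
# Bundle all per-variable updates into a single op to run once per step
train_op = tf.group(*var_updates)
```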
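
To see what `train_op` does numerically, here is a framework-free sketch in plain Python (not TensorFlow code). The quadratic loss, its hand-written gradient, and the helper names `toy_loss`, `toy_grads`, `gradient_descent_step` and `train` are hypothetical stand-ins for your `loss`, `tf.gradients(loss, var_list)` and repeated `sess.run(train_op)` calls.

```python
# Framework-free sketch of the update that train_op applies on each run:
#   var <- var - learning_rate * grad   (for every variable)
# The toy loss below is hypothetical; in TensorFlow the gradients would come
# from tf.gradients(loss, var_list) instead of a hand-written function.


def toy_loss(w):
    # Hypothetical loss with its minimum at w = [3.0, -1.0]
    return (w[0] - 3.0) ** 2 + (w[1] + 1.0) ** 2


def toy_grads(w):
    # Analytic gradient of toy_loss: one entry per "variable"
    return [2.0 * (w[0] - 3.0), 2.0 * (w[1] + 1.0)]


def gradient_descent_step(var_list, grads, learning_rate):
    # Same arithmetic as var.assign_sub(learning_rate * grad) for each pair
    return [var - learning_rate * grad for grad, var in zip(grads, var_list)]


def train(var_list, learning_rate=0.1, steps=100):
    # Each iteration plays the role of one sess.run(train_op)
    for _ in range(steps):
        var_list = gradient_descent_step(var_list, toy_grads(var_list), learning_rate)
    return var_list


weights = train([0.0, 0.0])
print(weights)  # both entries end up very close to 3.0 and -1.0
```

For this particular loss and a learning rate of 0.1, each step shrinks the distance to the minimum by a factor of 0.8, so 100 steps are more than enough.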

You can train your model by calling `sess.run(train_op)`. Before actually updating your variables, you can do all sorts of things with the gradients. For instance, you can keep track of the gradients in a separate set of variables and use them for the momentum algorithm, or you can clip your gradients before applying the update. All of these are ordinary TensorFlow operations, because gradient tensors are no different from any other tensors you compute in TensorFlow. Look at the implementations of some of the fancier optimization algorithms (Momentum, RMSProp, Adam) to understand how you can implement your own.
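
As a sketch of the momentum idea under the same toy setup, the plain-Python code below keeps one accumulator per variable and applies the same two-line rule that `tf.train.MomentumOptimizer` documents (without Nesterov). In a TensorFlow graph the accumulators would instead be extra non-trainable variables initialized to zero and updated with an `assign` op before the weight update (for example, ordered via `tf.control_dependencies`); the quadratic loss and the helper names `toy_grads`, `momentum_step` and `train_with_momentum` are hypothetical.

```python
# Framework-free sketch of gradient descent with momentum:
#   accumulation <- momentum * accumulation + grad
#   var          <- var - learning_rate * accumulation
# The accumulators are the "memory" of previous gradients. In TensorFlow they
# would be extra non-trainable variables; here they are plain floats.


def toy_grads(w):
    # Hypothetical gradient of (w0 - 3)^2 + (w1 + 1)^2
    return [2.0 * (w[0] - 3.0), 2.0 * (w[1] + 1.0)]


def momentum_step(var_list, accumulators, grads, learning_rate, momentum):
    new_vars, new_accs = [], []
    for var, acc, grad in zip(var_list, accumulators, grads):
        # Gradients are ordinary values, so they could be clipped or
        # otherwise transformed here before being accumulated.
        acc = momentum * acc + grad  # decaying sum of all previous gradients
        new_accs.append(acc)
        new_vars.append(var - learning_rate * acc)
    return new_vars, new_accs


def train_with_momentum(var_list, learning_rate=0.05, momentum=0.9, steps=500):
    accumulators = [0.0] * len(var_list)  # start with no gradient history
    for _ in range(steps):
        grads = toy_grads(var_list)
        var_list, accumulators = momentum_step(
            var_list, accumulators, grads, learning_rate, momentum
        )
    return var_list


weights = train_with_momentum([0.0, 0.0])
print(weights)  # both entries end up very close to 3.0 and -1.0
```

Setting `momentum=0.0` reduces `momentum_step` to the plain gradient descent update, which is a quick way to check that the bookkeeping is right.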

keveman answered Sep 25 '22 01:09