 

How can I intercept the gradient from automatic differentiation in TensorFlow?

Tags:

tensorflow

Let's say I have two consecutive layers with activations a1 and a2. Is there a way to intercept the gradient that automatic differentiation propagates from layer 2 to layer 1, i.e. ∂E/∂a2? I would like to change this gradient and then pass it on to layer 1.

Asked Mar 12 '23 by Lenar Hoyt

2 Answers

From the tf.train.Optimizer documentation:

Processing gradients before applying them.

Calling minimize() takes care of both computing the gradients and applying them to the variables. If you want to process the gradients before applying them you can instead use the optimizer in three steps:

1. Compute the gradients with compute_gradients().
2. Process the gradients as you wish.
3. Apply the processed gradients with apply_gradients().

Example:

# Create an optimizer.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)

# grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)
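
To make the "process the gradients" step concrete, here is a minimal sketch in which the placeholder MyCapper is assumed to be element-wise clipping with tf.clip_by_value; the loss tensor is assumed to be defined elsewhere:

import tensorflow as tf

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss)  # loss defined elsewhere

# Clip each gradient to [-1, 1] before applying it.
# compute_gradients() may return None gradients, so guard against that.
capped_grads_and_vars = [
    (tf.clip_by_value(g, -1.0, 1.0) if g is not None else None, v)
    for g, v in grads_and_vars
]

train_op = opt.apply_gradients(capped_grads_and_vars)

Note that this modifies the gradients with respect to the variables, not the gradient flowing between activations, so it only partially answers the original question.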
Answered Apr 28 '23 by DomJack

You might be looking for tf.Graph.gradient_override_map. There is a good example in the TensorFlow docs:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  # ...

with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  s_1 = tf.square(c)  # Uses the default gradient for tf.square.
  with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
                        # gradient of s_2.

There is a real-world use of it here to pass the real-valued gradient back through quantized weights in a DoReFa-Net implementation.
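
Applied to the original question (intercepting ∂E/∂a2 between two layers), a minimal sketch of the same idea is to insert a tf.identity op between the layers and override its gradient. The op name "ScaleGrad", the layer sizes, and the 0.5 scaling factor below are illustrative assumptions, not from the answer:

import tensorflow as tf

@tf.RegisterGradient("ScaleGrad")
def _scale_grad(op, grad):
  # grad is the gradient arriving from layer 2 (dE/da2's backward signal);
  # modify it here before it continues on to layer 1.
  return grad * 0.5  # illustrative modification

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, [None, 4])
  a1 = tf.layers.dense(x, 8, activation=tf.nn.relu)     # layer 1
  with g.gradient_override_map({"Identity": "ScaleGrad"}):
    a1_hook = tf.identity(a1)                            # interception point
  a2 = tf.layers.dense(a1_hook, 1)                       # layer 2
  loss = tf.reduce_mean(tf.square(a2))
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

In the forward pass tf.identity is a no-op, but in the backward pass the overridden gradient function runs, so the modified gradient is what reaches layer 1.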

Answered Apr 28 '23 by user728291