How to make sure your computation graph is differentiable

Some TensorFlow operations (e.g. tf.argmax) are not differentiable, i.e. no gradients are computed for them and used in back-propagation.

An answer to "Tensorflow what operations are differentiable and what are not?" suggests searching for RegisterGradient in the TensorFlow source. I also noticed TensorFlow has a tf.NotDifferentiable API call for declaring an operation to be non-differentiable.

Is there a warning issued if I use non-differentiable functions? Is there a programmatic way to ensure that my entire computation graph is differentiable?

eitanrich asked Nov 22 '25 22:11

1 Answer

Most floating point operations will have gradients, so a first-pass answer would just be to check that there are no int32/int64 dtype Tensors in the graph. This is easy to do, but probably not useful on its own (i.e. any non-trivial model will be doing non-differentiable indexing operations somewhere).
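A minimal sketch of that first pass (graph-mode TensorFlow; `integer_tensors` is a name I'm making up for illustration):

```python
import tensorflow as tf

def integer_tensors(graph):
    """Crude first pass: list every tensor in `graph` with an integer
    dtype, since no gradient can flow through those."""
    found = []
    for op in graph.get_operations():
        for t in op.outputs:
            if t.dtype in (tf.int32, tf.int64):
                found.append(t)
    return found

g = tf.Graph()
with g.as_default():
    x = tf.constant([3.0, 1.0, 2.0])
    idx = tf.argmax(x)  # produces an integer index tensor

# Flags the argmax output (and any other integer tensors) as
# places where gradients stop.
print([t.name for t in integer_tensors(g)])
```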

You could do some type of introspection, looping over the operations in the GraphDef and checking that they have gradients registered. I would argue that this is not terribly useful either; if we don't trust that gradients are registered in the first place, why trust that they're correct if registered?

Instead, I'd go with numerical gradient checking at a few points for your model. For example, let's say we register a PyFunc without a gradient:

import tensorflow as tf
import numpy

def my_func(x):
  # Opaque to TensorFlow: py_func ops have no registered gradient.
  return numpy.sinh(x)

with tf.Graph().as_default():
  inp = tf.placeholder(tf.float32)
  # y = sinh(inp) + inp; only the "+ inp" term contributes a gradient.
  y = tf.py_func(my_func, [inp], tf.float32) + inp
  grad, = tf.gradients(y, inp)
  with tf.Session() as session:
    print(session.run([y, grad], feed_dict={inp: 3}))
    # Compare the symbolic gradient against a numerical estimate.
    print("Gradient error:", tf.test.compute_gradient_error(inp, [], y, []))

This gets me output like:

[13.017875, 1.0]
Gradient error: 1.10916996002

Numerical gradients can be a bit tricky, but generally any gradient error more than a few orders of magnitude above machine epsilon (~1e-7 for float32) would raise red flags for me for a supposedly smooth function.
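As a standalone illustration of the idea (plain NumPy, independent of TensorFlow), a central-difference estimate for a smooth function like sinh agrees with the analytic derivative to well below that threshold:

```python
import numpy as np

def numerical_grad(f, x, eps=1e-5):
    # Central difference: O(eps**2) accurate for smooth f.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

x = 3.0
approx = numerical_grad(np.sinh, x)
exact = np.cosh(x)          # d/dx sinh(x) = cosh(x)
print(abs(approx - exact))  # far below 1e-6 for this smooth function
```

When the symbolic gradient is wrong or missing (as with the py_func above), this kind of comparison blows up to an O(1) error instead.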

Allen Lavoie answered Nov 26 '25 17:11


