I'm building a simple multilayer perceptron with TensorFlow, and I also need to obtain the gradients (or error signal) of the loss at the neural network's inputs.
Here's my code, which works:
```python
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=self.network, labels=self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
# ...
for i in range(epochs):
    # ...
    for batch in batches:
        # ...
        sess.run(optimizer, feed_dict=feed_dict)
        # Calls tf.gradients() on every iteration.
        grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]
```
Without the last line (`grads_wrt_input = ...`), this runs really fast on a CUDA machine. However, adding the `tf.gradients()` call slows it down by a factor of ten or more.
I recall that the error signals at the nodes are computed as intermediate values in the backpropagation algorithm, and I have successfully done this using the Java library DeepLearning4j. I was also under the impression that this would only be a slight modification to the computation graph already built by `optimizer`.
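To make the point about intermediate error signals concrete, here is a minimal NumPy sketch (not TensorFlow or DeepLearning4j) of backpropagation through a hypothetical one-hidden-layer tanh network with a mean softmax cross-entropy loss. The layer sizes, weights, and labels are made up for illustration, and NumPy is assumed to be installed. The gradient with respect to the inputs costs one extra matrix product on top of the weight gradients, and a central finite-difference check compares it against a numerical estimate.

```python
import numpy as np

# Hypothetical toy MLP: 3 inputs -> 5 tanh units -> 2 classes.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))              # batch of 4 examples, 3 features
y = np.eye(2)[[0, 1, 1, 0]]              # one-hot labels for 2 classes
W1, b1 = rng.normal(size=(3, 5)), np.zeros(5)
W2, b2 = rng.normal(size=(5, 2)), np.zeros(2)


def loss_and_grads(x):
    # Forward pass: keep the intermediates for the backward pass.
    h = np.tanh(x @ W1 + b1)
    logits = h @ W2 + b2
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)    # softmax probabilities
    loss = -np.mean(np.sum(y * np.log(p), axis=1))

    # Backward pass: the error signal is propagated layer by layer.
    n = x.shape[0]
    d_logits = (p - y) / n               # dL/dlogits for the mean cross-entropy
    dW2 = h.T @ d_logits
    d_hpre = (d_logits @ W2.T) * (1 - h ** 2)  # tanh'(z) = 1 - tanh(z)^2
    dW1 = x.T @ d_hpre
    dx = d_hpre @ W1.T                   # one extra product: error signal at the inputs
    return loss, dW1, dW2, dx


loss, dW1, dW2, dx = loss_and_grads(x)

# Numerical check of dx with central finite differences.
eps = 1e-6
num_dx = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += eps
        xm[i, j] -= eps
        num_dx[i, j] = (loss_and_grads(xp)[0] - loss_and_grads(xm)[0]) / (2 * eps)

print("max abs difference:", np.max(np.abs(dx - num_dx)))
```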
How can this be made faster, or is there any other way to compute the gradients of the loss w.r.t. the inputs?
The `tf.gradients()` function adds a new backpropagation subgraph to the graph each time it is called, so the reason for the slowdown is that TensorFlow has to process a new, larger graph on each iteration of the loop. (This can be surprisingly expensive: the current version of TensorFlow is optimized for executing the same graph a large number of times.)
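A sketch of the effect, assuming TensorFlow 2.x run in TF1-style graph mode through `tf.compat.v1` (the tiny `build_model()` helper is hypothetical and only stands in for the question's network): counting the operations in the graph shows that every `tf.gradients()` call inside a loop adds new ops, whereas calling it once leaves the graph size unchanged from one iteration to the next.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode, as in the question


def build_model():
    # Hypothetical stand-in for self.x and the cost built from self.network.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
    w = tf.Variable(tf.ones([3, 2]))
    cost = tf.reduce_sum(tf.matmul(x, w))
    return x, cost


# Pattern from the question: tf.gradients() called on every iteration.
g1 = tf.Graph()
with g1.as_default():
    x, cost = build_model()
    sizes_in_loop = []
    for _ in range(3):
        tf.compat.v1.gradients(cost, x)          # adds a new gradient subgraph
        sizes_in_loop.append(len(g1.get_operations()))

# Fixed pattern: tf.gradients() called once, before the loop.
g2 = tf.Graph()
with g2.as_default():
    x, cost = build_model()
    grad_x = tf.compat.v1.gradients(cost, x)[0]  # built once
    sizes_once = []
    for _ in range(3):
        sizes_once.append(len(g2.get_operations()))

print(sizes_in_loop)  # grows on every iteration
print(sizes_once)     # stays the same
```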
Fortunately the solution is easy: just compute the gradients once, outside the loop. You can restructure your code as follows:
```python
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=self.network, labels=self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
# Build the input-gradient ops once, before the training loop.
grads_wrt_input_tensor = tf.gradients(cost, self.x)[0]
# ...
for i in range(epochs):
    # ...
    for batch in batches:
        # ...
        # One run call: applies the update and fetches the input gradients.
        _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor],
                                      feed_dict=feed_dict)
```
Note that, for performance, I also combined the two `sess.run()` calls. This ensures that the forward propagation, and much of the backpropagation, will be reused.
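For reference, here is a self-contained, runnable version of the restructured loop. It is a sketch under several assumptions: TensorFlow 2.x with `tf.compat.v1` graph mode instead of the original TF1 API, a hypothetical two-layer network standing in for `self.network`, and random fake data in place of the real batches.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode

# Hypothetical stand-ins for self.x, self.y and self.network.
n_in, n_hidden, n_classes = 4, 8, 3
x = tf.compat.v1.placeholder(tf.float32, [None, n_in])
y = tf.compat.v1.placeholder(tf.float32, [None, n_classes])
W1 = tf.Variable(tf.random.normal([n_in, n_hidden], seed=0))
b1 = tf.Variable(tf.zeros([n_hidden]))
W2 = tf.Variable(tf.random.normal([n_hidden, n_classes], seed=1))
b2 = tf.Variable(tf.zeros([n_classes]))
network = tf.matmul(tf.nn.relu(tf.matmul(x, W1) + b1), W2) + b2

cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=network))
optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=0.1).minimize(cost)
# Built once, outside the loop.
grads_wrt_input_tensor = tf.compat.v1.gradients(cost, x)[0]

# Fake data for illustration only: 32 examples split into batches of 8.
rng = np.random.default_rng(0)
data_x = rng.normal(size=(32, n_in)).astype(np.float32)
data_y = np.eye(n_classes, dtype=np.float32)[rng.integers(0, n_classes, size=32)]

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for epoch in range(2):
        for start in range(0, 32, 8):
            feed_dict = {x: data_x[start:start + 8], y: data_y[start:start + 8]}
            # A single run call performs the update and returns dcost/dx.
            _, grads_wrt_input = sess.run(
                [optimizer, grads_wrt_input_tensor], feed_dict=feed_dict)

print(grads_wrt_input.shape)
```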
As an aside, one tip to find performance bugs like this is to call `tf.get_default_graph().finalize()` before starting your training loop. This will raise an exception if you inadvertently add any nodes to the graph, which makes it easier to trace the cause of these bugs.
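A minimal sketch of that check, again assuming TensorFlow 2.x with `tf.compat.v1` graph mode and a hypothetical toy graph: once the graph is finalized, the offending `tf.gradients()` call fails immediately with a `RuntimeError` instead of silently growing the graph.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode

# Hypothetical toy model built before training starts.
x = tf.compat.v1.placeholder(tf.float32, [None, 3])
w = tf.Variable(tf.ones([3, 1]))
cost = tf.reduce_sum(tf.matmul(x, w))

graph = tf.compat.v1.get_default_graph()
graph.finalize()  # from here on, adding any op raises an error

caught = None
try:
    # The bug from the question: building gradient ops inside the loop.
    tf.compat.v1.gradients(cost, x)
except RuntimeError as err:
    caught = err

print("caught:", repr(caught))
```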