
TensorFlow's Print or K.print_tensor are not printing intermediate tensors in loss function

I have written a rather complex loss function for a Keras model, and it keeps returning nan during training. Therefore, I need to print the intermediate tensors while training. I understand that you cannot call K.eval in a loss function because the tensors are not initialized at compile time. However, I have tried both K.print_tensor() and tf.Print(), and neither works.

Essentially, I want to do something like this:

import tensorflow as tf
import keras.backend as K

def mean_squared_error(y_true, y_pred):
    print("mean_squared_error")  # fires once, when the graph is built
    loss = K.mean(K.square(y_pred - y_true), axis=-1)
    loss = tf.Print(loss, [loss])  # should print the loss every step, but never does
    return loss

model.compile(optimizer=self.optimizer, loss=mean_squared_error)

In practice, I would replace mean_squared_error with my custom loss. The string "mean_squared_error" does get printed, but not the values I try to print with tf.Print (or with K.print_tensor). I also tried the exact code from How do I print inside the loss function during training in Keras? and still see nothing printed in the console.

In addition, I wrote a separate file to test the printing in isolation.

import tensorflow as tf
import keras.backend as K

input1 = K.constant(1)
input2 = K.constant(2)
input3 = K.constant(3)

node1 = tf.add(input1, input2)
print_output = K.print_tensor(node1)
output = tf.multiply(print_output, input3)

Nothing gets printed either.

Am I using TensorFlow's tf.Print and Keras's K.print_tensor incorrectly? Or are the results printed somewhere else? I have verified that my console shows stderr by running print("test", file=sys.stderr), which printed test as expected.

For clarification, I know that K.eval makes the test code print the tensor's values, but since I cannot use K.eval inside my loss function, I need tf.Print or K.print_tensor to work.
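For reference, here is the same test file with a K.eval call added at the end; this version does print (a minimal sketch of the working case):

import tensorflow as tf
import keras.backend as K

input1 = K.constant(1)
input2 = K.constant(2)
input3 = K.constant(3)

node1 = tf.add(input1, input2)
print_output = K.print_tensor(node1)
output = tf.multiply(print_output, input3)

# K.eval forces the graph to execute, so print_tensor fires and 9.0 is returned
print(K.eval(output))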

Asked by Leo Appleseed

1 Answer

The issue here is that the training code often does not actually depend on the value of the loss tensor! Usually you can compute the gradient of a loss without ever computing the loss value itself, which means TensorFlow's runtime is free to prune the actual execution of the loss (and of any print ops attached to it) from the graph.
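To see this concretely, here is a minimal TF 1.x graph-mode sketch (the names here are purely illustrative): fetching only the gradient does not need the loss value, so the Print op attached to the loss is typically never executed.

import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.Print(x * x, [x * x], message="loss: ")
# d(loss)/dx = 2*x only needs x, not the value of x*x
grad = tf.gradients(loss, x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(grad)  # the Print op is typically pruned here: nothing appears
    sess.run(loss)  # fetching the loss itself does trigger the Print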

You can wrap your loss function in a tf.contrib.eager.defun decorator, which has the side effect of guaranteeing that all stateful ops in your function run even if they are not needed by the backward pass.
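For example, a minimal sketch of that wrapping (assuming TF 1.x where tf.contrib.eager.defun exists; the optimizer and the message argument are illustrative, not from the question):

import tensorflow as tf
import keras.backend as K

@tf.contrib.eager.defun
def mean_squared_error(y_true, y_pred):
    loss = K.mean(K.square(y_pred - y_true), axis=-1)
    # defun guarantees stateful ops like this Print run every step
    loss = tf.Print(loss, [loss], message="loss: ")
    return loss

# `model` is assumed to be an existing Keras model, as in the question
model.compile(optimizer="adam", loss=mean_squared_error)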

Answered by Alexandre Passos