Why does tf.Print() not work?

I have this code snippet:

import tensorflow as tf
import numpy as np

# batch x time x events
batch = 2
time = 3
events = 4
tensor = np.random.rand(batch, time, events)

tensor[0][0][2] = 0
tensor[0][0][3] = 0

tensor[0][1][3] = 0

tensor[0][2][1] = 0
tensor[0][2][2] = 0
tensor[0][2][3] = 0

tensor[1][0][3] = 0

def cum_sum(prev, cur):
    non_zeros = tf.equal(cur, 0.)
    tf.Print(non_zeros, [non_zeros], "message ")
    tf.Print(cur, [cur])
    return cur

elems = tf.constant([1,2,3],dtype=tf.int64)
#alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
cum_sum_ = tf.scan(cum_sum, tensor)

s = tf.Session()

s.run(cum_sum_)

I have two tf.Print statements in the function that I pass to tf.scan, but when I run the cumulative sum, I do not see any printed output. Am I doing something wrong?

asked Jul 19 '18 at 12:07 by Naz


2 Answers

tf.Print does not work like that. A Print node only executes if it is part of the graph that actually runs, i.e. if something you evaluate depends on its output. I highly recommend checking out this tutorial to learn how to use it.

If you have any questions feel free to ask.
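To see why the Print calls in the question never fire, here is a toy model of graph-mode (lazy) execution in plain Python. It is only an illustration under stated assumptions, not TensorFlow itself: the names Node, run and print_node are hypothetical. Running a node evaluates just the nodes it depends on, so a print node whose output nobody consumes is never evaluated.

```python
# Toy model of graph-mode (lazy) execution; NOT TensorFlow's API.
# Node, run and print_node are hypothetical names used only for illustration.

class Node:
    """A deferred computation: fn is applied to the values of the input nodes."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

def run(fetch, cache=None):
    """Evaluate `fetch` and, recursively, only the nodes it depends on."""
    if cache is None:
        cache = {}
    if fetch not in cache:
        args = [run(node, cache) for node in fetch.inputs]
        cache[fetch] = fetch.fn(*args)
    return cache[fetch]

log = []  # collects what the toy print nodes "print"

def print_node(x, message):
    """Identity node that records its input's value as a side effect when evaluated."""
    def fn(value):
        log.append(message + str(value))
        return value
    return Node(fn, x)

cur = Node(lambda: [1, 0, 3])          # a constant source node
dangling = print_node(cur, "dangling: ")  # output discarded, as in the question
used = print_node(cur, "used: ")          # output kept...
total = Node(lambda v: sum(v), used)      # ...and consumed downstream

result = run(total)
print(result)  # 4
print(log)     # only the "used: " entry; the dangling node never ran
```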

answered Oct 20 '22 at 01:10 by Ignacio Peletier


Although @Ignacio Peletier's answer is helpful, it depends on an external site. I find it peculiar that no one has mentioned that here. Anyway, to stick with the rules, I provide the answer here as well (so there is no need to visit the external link), with some extra information:

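For tf.Print to actually print something, it must belong to the part of the graph that gets executed. To achieve this, simply reuse the tensor returned by tf.Print and pass it to the next op. tf.Print is an identity op: it returns its first argument unchanged and prints the listed tensors as a side effect whenever it is evaluated. Passing its output on to the next op is therefore crucial for actually displaying the message.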

So, your example could be rewritten as:

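def cum_sum(prev, cur):
    non_zeros = tf.equal(cur, 0.)
    # non_zeros is not used after this line, so this Print node is dangling and never runs
    non_zeros = tf.Print(non_zeros, [non_zeros], "message ")
    # cur is returned, so tf.scan evaluates this Print node and its output is printed
    cur = tf.Print(cur, [cur])
    return cur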

which will print cur but not non_zeros, since the latter node is dangling. Also, I am not sure how to rewrite your code so that non_zeros is displayed, since it is not actually used anywhere after the point where you define it (so TensorFlow will simply ignore it in non-eager, i.e. graph, mode).
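One option that may help here: the second argument of tf.Print is a list of tensors to print, and it does not have to contain the first argument. In TF 1.x, writing cur = tf.Print(cur, [non_zeros, cur], "message ") inside cum_sum should therefore display non_zeros too, because that Print node sits on the path tf.scan evaluates. The plain-Python toy below models this idea; Node, run and print_node are hypothetical helpers, not TensorFlow's API.

```python
# Toy model (NOT TensorFlow): a print node that is on the evaluated path can
# also print values of other nodes listed in its `data` argument.

class Node:
    """A deferred computation: fn is applied to the values of the input nodes."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

def run(fetch, cache=None):
    """Evaluate `fetch` and, recursively, only the nodes it depends on."""
    if cache is None:
        cache = {}
    if fetch not in cache:
        cache[fetch] = fetch.fn(*[run(n, cache) for n in fetch.inputs])
    return cache[fetch]

log = []

def print_node(x, data, message):
    """Toy counterpart of tf.Print(x, data, message): returns x's value unchanged
    and records the values of every node in `data` when evaluated."""
    def fn(value, *data_values):
        log.append(message + " ".join(str(d) for d in data_values))
        return value
    return Node(fn, x, *data)  # data nodes become inputs, so they get evaluated

cur = Node(lambda: [1, 0, 3])
non_zeros = Node(lambda v: [x == 0 for x in v], cur)  # like tf.equal(cur, 0.)

# Dangling: nothing consumes this node's output, so it never runs.
print_node(non_zeros, [non_zeros], "dangling ")
# Attach non_zeros to the print node on the returned path instead.
cur_printed = print_node(cur, [non_zeros, cur], "message ")

output = run(cur_printed)
print(output)  # [1, 0, 3], unchanged
print(log)     # a single "message ..." entry containing non_zeros and cur
```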

In general, for a (non-dangling) node somewhere in your code, let's call it result:

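import tensorflow as tf  # TF 1.x (graph mode)
result = tf.constant([1, 2, 3])  # stands in for any op whose output you want to inspect
# tf.Print returns result unchanged and displays [result] whenever it is evaluated
result = tf.Print(result, [result], 'some arbitrary message here to be displayed before the actual message')
doubled = result * 2  # a downstream op that uses result, so the Print node runs
with tf.Session() as s:
    s.run(doubled)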

will print the (possibly truncated) contents of result. To control the amount of information shown, you can also use the parameter summarize=x, where x is the number of entries of each tensor to display.
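The truncation comes from summarize: according to the TF 1.x documentation, tf.Print prints at most summarize entries of each tensor, and at most 3 when summarize is None. The hypothetical helper summarize_tensor below is a plain-Python toy that imitates this for a 1-D list; the exact text layout TensorFlow prints may differ.

```python
# Toy imitation of tf.Print's `summarize` behaviour for a 1-D list.
# summarize_tensor is a hypothetical helper; TensorFlow's real output format may differ.

def summarize_tensor(values, summarize=None):
    """Keep at most `summarize` entries (3 when None) and mark a cut with '...'."""
    limit = 3 if summarize is None else summarize
    shown = " ".join(str(v) for v in values[:limit])
    return "[" + shown + ("..." if len(values) > limit else "") + "]"

print(summarize_tensor([1, 2, 3, 4, 5]))               # default: first 3 entries
print(summarize_tensor([1, 2, 3, 4, 5], summarize=5))  # all 5 entries
```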

answered Oct 20 '22 at 01:10 by Eypros