I am looking for a way to add a conditional print node in TensorFlow. With the sample line of code below, it should print something to the console every 10 loop iterations, but it does not work for me. Can anyone advise?
Thanks,
Hamidreza
```python
epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool),
                  true_fn=lambda: tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10),
                  false_fn=lambda: epsilon)
```
I encountered a similar issue and solved it with the following (admittedly ugly) workaround.
I use tf.print instead of tf.Print, since tf.Print is deprecated.
Because true_fn and false_fn must return the same type and shape, both branches return a meaningless constant:
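```python
import tensorflow as tf

def true_func():
    # Create the print op inside the branch so it only runs when the condition holds.
    printfunc = tf.print(TENSOR_TO_PRINT, summarize=-1)
    # The constant is created under this context, so it depends on the print op.
    with tf.control_dependencies([printfunc]):
        return tf.constant(1)

def false_func():
    # Same dtype and shape as true_func's result, but nothing is printed.
    return tf.constant(1)

shouldprint = tf.cond(YOUR_CONDITION, true_fn=true_func, false_fn=false_func)
```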
Then you can run the shouldprint op or add it as a dependency of other ops.
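A minimal sketch of the same pattern applied to the loop from the question, assuming TensorFlow 2.x; the helper name maybe_print and the loop in run are hypothetical. The likely reason the original line never prints is that counter % 10 == 0 inside tf.constant is evaluated in Python once, while the graph is being built, so the predicate is frozen. Building it from tensor ops (tf.equal(tf.math.floormod(counter, 10), 0)) makes it evaluate on every iteration.

```python
import tensorflow as tf

def maybe_print(counter, epsilon, every=10):
    # Tensor-valued predicate: evaluated each time the op runs, not at trace time.
    should_print = tf.equal(tf.math.floormod(counter, every), 0)

    def true_fn():
        print_op = tf.print("batch:", counter, epsilon, summarize=10)
        # tf.identity creates a new tensor inside the context, so it carries
        # the control dependency on the print op (in graph mode).
        with tf.control_dependencies([print_op]):
            return tf.identity(epsilon)

    def false_fn():
        return tf.identity(epsilon)

    return tf.cond(should_print, true_fn, false_fn)

@tf.function
def run(n):
    eps = tf.constant(0.5)
    # AutoGraph turns this into a tf.while_loop; prints at i = 0, 10, 20, ...
    for i in tf.range(n):
        eps = maybe_print(i, eps)
    return eps
```

Returning tf.identity(epsilon) rather than epsilon itself matters in graph mode: tf.control_dependencies only affects ops created inside the with block, and epsilon was created outside it.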