
Conditional print node in tensorflow

I am looking for a way to add a conditional print node in TensorFlow. With the sample line of code below, I expect something to be printed to the console on every 10th loop iteration, but it does not work for me. Can anyone advise?

Thanks, Hamidreza,

epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool), true_fn=lambda:tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10), false_fn=lambda:epsilon)
HamidReza Mirkhani asked Nov 23 '25 00:11



1 Answer

I encountered a similar issue and solved it with the following ugly solution.

I am using tf.print instead of tf.Print, since tf.Print is deprecated.

Because true_fn and false_fn must return values of the same type and shape, both functions return a meaningless constant:

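import tensorflow as tf

def true_func():
    # Create the print op, then force it to run before the returned constant.
    printfunc = tf.print(TENSOR_TO_PRINT, summarize=-1)
    with tf.control_dependencies([printfunc]):
        return tf.constant(1)

def false_func():
    # Must match the type and shape returned by true_func.
    return tf.constant(1)

# TENSOR_TO_PRINT and YOUR_CONDITION are placeholders for your own tensors.
shouldprint = tf.cond(YOUR_CONDITION, true_fn=true_func, false_fn=false_func)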

Then you can run the shouldprint op or add it as a dependency of other ops.
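
For TensorFlow 2.x, the same idea can probably be written more directly. The sketch below is not from the question or the answer above. It assumes TF 2.x with tf.function, and the name maybe_print is made up for illustration. It passes counter in as a tensor so the condition is evaluated each time the graph runs. If counter in the question is a plain Python int, then counter % 10 == 0 is computed once while the graph is built, which may be why the original line never switches branches.

```python
import tensorflow as tf

@tf.function
def maybe_print(counter, epsilon):
    # Hypothetical helper: prints only when counter is a multiple of 10.
    def print_branch():
        # Inside tf.function, tf.print is ordered automatically,
        # so no explicit control_dependencies block is needed here.
        tf.print("batch:", counter, "epsilon:", epsilon, summarize=10)
        return tf.identity(epsilon)

    def skip_branch():
        # Same dtype and shape as print_branch, as tf.cond requires.
        return tf.identity(epsilon)

    # counter is a tensor, so this condition is evaluated at run time.
    return tf.cond(tf.equal(counter % 10, 0), print_branch, skip_branch)

epsilon = tf.constant(0.5)
for step in range(30):
    # Wrapping step in tf.constant avoids retracing for every new Python int.
    epsilon = maybe_print(tf.constant(step), epsilon)
```

By default tf.print writes to standard error, so the messages may not show up where standard output is captured.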

hsc answered Nov 24 '25 12:11
