
How to perform thresholding on a tensor

The following is my code:

    ...
    result = tf.div(product_norm, denom)
    if result > 0.5:
        result = 1
    else:
        result = 0
    return result

Each value in the tensor that is less than 0.5 should be replaced with 0, and every other value with 1. Instead, the code keeps raising this error:

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
user3102085 asked Jan 07 '17 16:01



1 Answer

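In this simple case, as long as the values stay between 0 and 1, you can add 0.5 and cast to an integer, which truncates the fractional part: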

result = tf.cast(result + 0.5, tf.int32)
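To see why the add-and-cast trick works, and where it stops working, here is a small pure-Python sketch of the arithmetic. It assumes that int() can stand in for the float-to-integer truncation done by tf.cast; the function name threshold_by_truncation and the sample values are made up for illustration.

```python
def threshold_by_truncation(x):
    # int() truncates toward zero; used here as a stand-in (assumption)
    # for the truncation performed by tf.cast(x + 0.5, tf.int32).
    return int(x + 0.5)

# The first five samples lie in [0, 1]; the last two fall outside that range.
samples = [0.0, 0.49, 0.5, 0.99, 1.0, 1.7, -1.8]
results = [threshold_by_truncation(x) for x in samples]

for x, r in zip(samples, results):
    print(f"{x:>5} -> {r}")
```

Inside [0, 1] every value lands on 0 or 1, but 1.7 becomes 2 and -1.8 becomes -1, so the trick is only safe when the input range is known.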

When the condition becomes more complex, consider using tf.cond, which takes a scalar boolean tensor as its predicate.
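For values outside the [0, 1] range, or for a whole tensor at once, an explicit comparison may be easier to reason about. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution (the question's tf.div is a TensorFlow 1.x call). The function names are hypothetical, and the sketch uses >= so that 0.5 maps to 1, as the question's description asks.

```python
import tensorflow as tf

def threshold(values, cutoff=0.5):
    """Elementwise: 1 where values >= cutoff, else 0 (int32)."""
    values = tf.convert_to_tensor(values, dtype=tf.float32)
    # The comparison yields a boolean tensor; casting turns True/False into 1/0.
    return tf.cast(values >= cutoff, tf.int32)

def threshold_where(values, cutoff=0.5):
    """Same result via tf.where, which picks per element from two tensors."""
    values = tf.convert_to_tensor(values, dtype=tf.float32)
    return tf.where(values >= cutoff, tf.ones_like(values), tf.zeros_like(values))

def threshold_scalar(value, cutoff=0.5):
    """Scalar only: tf.cond runs one of two branches based on a scalar predicate."""
    value = tf.convert_to_tensor(value, dtype=tf.float32)
    return tf.cond(value >= cutoff, lambda: tf.constant(1), lambda: tf.constant(0))

print(threshold([0.2, 0.5, 0.9]).numpy())
print(threshold_where([0.2, 0.5, 0.9]).numpy())
print(threshold_scalar(0.7).numpy())
```

The comparison-and-cast form is usually enough for a plain threshold. tf.where is the more general choice when the two outcomes are tensors rather than the constants 0 and 1.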

standy answered Nov 03 '22 02:11
