
Issues with using == condition in tf.where()

Tags:

tensorflow

I have recently started using TensorFlow and have been playing around with the tf.where() function. I noticed that it throws an error whenever I use an "==" condition. For example, when I tried the following:

t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])

t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)

t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))

it throws the following error:

InvalidArgumentError: WhereOp : Unhandled input dimensions: 0

Could anyone please explain what might be the mistake here? Thanks in advance!

Saurabh Agrawal asked Jul 24 '18 00:07



1 Answer

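You need tf.equal to do an element-wise comparison. In TensorFlow 1.x, t == 2 is not evaluated element-wise; it yields a single Python bool, so tf.where receives a 0-dimensional input, which is what the error reports: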

t2 = tf.where(tf.equal(t, 2))

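import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

# tf.equal builds an element-wise boolean mask; tf.where returns the indices of its True entries
t2 = tf.where(tf.equal(t, 2))
# gather the values of t at those indices
t3 = tf.gather_nd(t, t2)
t3_shape = tf.shape(t)[0]  # not used below; carried over from the question

with tf.Session() as sess:
    print(sess.run([t3]))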

# [array([2])]
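
As a rough illustration of why the "==" version fails, here is a toy model in plain Python. It assumes a tensor-like class whose == compares object identity, which is how TensorFlow 1.x tensors behave. The names IdentityTensor, equal and where are hypothetical stand-ins written for this sketch, not TensorFlow code.

```python
# Toy model only: these classes and functions are hypothetical stand-ins,
# not TensorFlow's implementation.

class IdentityTensor:
    """Mimics a tensor type whose == compares object identity (one bool)."""

    def __init__(self, values):
        self.values = values

    def __eq__(self, other):
        # Identity comparison: the contents are never inspected.
        return self is other

    # Defining __eq__ sets __hash__ to None; restore identity hashing.
    __hash__ = object.__hash__


def equal(tensor, scalar):
    """Element-wise comparison, playing the role of tf.equal on a 2-D input."""
    return [[value == scalar for value in row] for row in tensor.values]


def where(mask):
    """Return [row, col] indices of True entries of a 2-D mask."""
    if isinstance(mask, bool):
        # A bare bool has no dimensions to index into.
        raise ValueError("where: unhandled input dimensions: 0")
    return [[r, c] for r, row in enumerate(mask)
            for c, flag in enumerate(row) if flag]


t = IdentityTensor([[1, 2, 3], [4, 5, 6]])

scalar_result = (t == 2)   # a single bool, not a 2x3 mask
mask = equal(t, 2)         # a 2x3 nested list of bools
indices = where(mask)      # positions where the value is 2
```

Passing scalar_result to where raises the error, while the mask from equal gives the expected index. In TensorFlow 2.x, == on tensors is element-wise by default, so this pitfall is specific to 1.x-style code.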
Psidom answered Oct 23 '22 18:10
