I recently started using TensorFlow and have been experimenting with the tf.where() function. I noticed that it raises an error whenever I use an "==" condition. For example, when I tried the following:
import tensorflow as tf

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
    print(sess.run([t3]))
it raises the following error:
InvalidArgumentError: WhereOp : Unhandled input dimensions: 0
Could anyone please explain what might be the mistake here? Thanks in advance!
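You need tf.equal to do an element-wise comparison. In TensorFlow 1.x, == is not overloaded element-wise for tensors, so t==2 evaluates to a single Python bool (False) rather than a boolean tensor, and tf.where then fails on that 0-dimensional input. Use: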
t2 = tf.where(tf.equal(t, 2))
import tensorflow as tf

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
t2 = tf.where(tf.equal(t, 2))  # coordinates of elements equal to 2
t3 = tf.gather_nd(t,t2)        # values at those coordinates
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
    print(sess.run([t3]))
# [array([2])]
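To illustrate why the "==" version fails, here is a minimal pure-Python sketch that does not use TensorFlow at all. FakeTensor, equal and where are hypothetical stand-ins written for this example, not TensorFlow's real classes or functions. It assumes TF 1.x behavior, where a tensor does not define an element-wise "==", so the comparison collapses to a single bool.

```python
# Hypothetical stand-in for a TF 1.x tensor; NOT TensorFlow's real class.
class FakeTensor:
    def __init__(self, rows):
        self.rows = rows
    # No __eq__ is defined, so "==" falls back to an identity comparison.


def equal(t, value):
    """Element-wise comparison, analogous in spirit to tf.equal."""
    return [[x == value for x in row] for row in t.rows]


def where(mask):
    """Coordinates of True entries, analogous in spirit to tf.where(condition)."""
    if isinstance(mask, bool):
        # Mirrors the idea behind "Unhandled input dimensions: 0".
        raise ValueError("condition is a 0-dimensional bool, not a mask")
    return [[i, j] for i, row in enumerate(mask) for j, v in enumerate(row) if v]


t = FakeTensor([[1, 2, 3], [4, 5, 6]])

naive = (t == 2)      # a single Python bool, not a per-element mask
mask = equal(t, 2)    # nested list of bools, one per element
coords = where(mask)  # [row, column] of every match

print(naive, coords)  # False [[0, 1]]
```

Calling where(naive) in this sketch raises a ValueError, which corresponds to the situation in the question: the condition passed in has no dimensions to search over.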