Zero all values except max in Tensorflow

I have an array [0.3, 0.5, 0.79, 0.2, 0.11].

I want to convert all values to zero except the max value. So the resulting array would be: [0, 0, 0.79, 0, 0]

What would be the best way to do this in a Tensorflow graph?

asked Jun 26 '26 22:06 by Shayan RC



1 Answer

If you want to keep all occurrences of the maximum, you could use

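import tensorflow as tf
a = tf.constant([0.3, 0.5, 0.79, 0.2, 0.11])
cond = tf.equal(a, tf.reduce_max(a))  # True wherever a equals its maximum
a_max = tf.where(cond, a, tf.zeros_like(a))  # keep those values, zero the rest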
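A minimal sketch of the same idea wrapped in a function, assuming TensorFlow 2.x where `@tf.function` traces the Python code into a graph, and assuming a 1-D input. The name `zero_all_but_max` is hypothetical. On an input with a tie it keeps every copy of the maximum.

```python
import tensorflow as tf

@tf.function  # traces the body into a TensorFlow graph (TF 2.x)
def zero_all_but_max(a):
    # True at every position whose value equals the maximum (ties included)
    cond = tf.equal(a, tf.reduce_max(a))
    # Keep the original value where cond is True, 0 elsewhere
    return tf.where(cond, a, tf.zeros_like(a))

print(zero_all_but_max(tf.constant([0.3, 0.5, 0.79, 0.2, 0.11])).numpy())
print(zero_all_but_max(tf.constant([0.3, 0.79, 0.79, 0.2])).numpy())
```

The comparison against `tf.reduce_max(a)` is exact here because the maximum is itself one of the elements of `a`.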

If you want to keep only one occurrence of the maximum, you could use

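argmax = tf.argmax(a)  # int64 index of one maximum
# scatter_nd needs the shape in the same int type as the indices; tf.to_int64 is gone in TF 2.x
a_max = tf.scatter_nd([[argmax]], [a[argmax]], tf.shape(a, out_type=tf.int64))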
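An alternative sketch for keeping a single maximum without `tf.scatter_nd`: compare every position against the `tf.argmax` index to get a one-hot boolean mask, then reuse `tf.where`. It assumes TensorFlow 2.x and a 1-D input; `keep_one_max` is a hypothetical name. It still relies on `tf.argmax`, so it inherits the same tie behavior.

```python
import tensorflow as tf

@tf.function
def keep_one_max(a):
    # Index of one maximum (int64 scalar); which one is unspecified on ties
    idx = tf.argmax(a)
    # Positions 0..n-1 as int64 so they can be compared with idx
    positions = tf.range(tf.shape(a, out_type=tf.int64)[0], dtype=tf.int64)
    # True only at the chosen index
    mask = tf.equal(positions, idx)
    return tf.where(mask, a, tf.zeros_like(a))

print(keep_one_max(tf.constant([0.3, 0.5, 0.79, 0.2, 0.11])).numpy())
```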

However, according to the documentation of tf.argmax:

Note that in case of ties the identity of the return value is not guaranteed

As I understand it, the maximum that is kept is not guaranteed to be the first or the last occurrence, and may not even be the same one if you run it twice on the same array.
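If a deterministic choice matters, one possible workaround (not from the answer above, a sketch assuming TensorFlow 2.x and a 1-D input) avoids `tf.argmax` entirely: mark all maxima, take a running count of them with `tf.cumsum`, and keep only the position where the count first reaches 1. The name `keep_first_max` is hypothetical.

```python
import tensorflow as tf

@tf.function
def keep_first_max(a):
    # True at every occurrence of the maximum
    is_max = tf.equal(a, tf.reduce_max(a))
    # Running number of maxima seen so far, scanning left to right
    count = tf.cumsum(tf.cast(is_max, tf.int32))
    # Only the first maximum has is_max True and count == 1
    keep = tf.logical_and(is_max, tf.equal(count, 1))
    return tf.where(keep, a, tf.zeros_like(a))

print(keep_first_max(tf.constant([0.3, 0.79, 0.79, 0.2])).numpy())
```

Passing `reverse=True` to `tf.cumsum` counts from the right instead, which keeps the last occurrence rather than the first.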

answered Jun 28 '26 12:06 by P-Gn



