I have an array [0.3, 0.5, 0.79, 0.2, 0.11].
I want to convert all values to zero except the max value. So the resulting array would be:
[0, 0, 0.79, 0, 0]
What would be the best way to do this in a TensorFlow graph?
If you want to keep all occurrences of the maximum, you could use
cond = tf.equal(a, tf.reduce_max(a))
a_max = tf.where(cond, a, tf.zeros_like(a))
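Here is a self-contained version of the first approach. It is a minimal sketch that assumes TensorFlow 2.x with eager execution; in a TF1 graph you would evaluate the result in a session instead. The helper name keep_all_max is my own, not part of TensorFlow:

```python
import tensorflow as tf  # sketch assumes TensorFlow 2.x with eager execution


def keep_all_max(x):
    """Zero every element of 1-D tensor x except those equal to its maximum."""
    # True wherever an element equals the maximum (all ties included)
    cond = tf.equal(x, tf.reduce_max(x))
    # Keep the max value(s); replace everything else with zeros of the same dtype
    return tf.where(cond, x, tf.zeros_like(x))


# Example input from the question
a = tf.constant([0.3, 0.5, 0.79, 0.2, 0.11])
print(keep_all_max(a).numpy())

# With a tie, every occurrence of the maximum survives
b = tf.constant([0.79, 0.5, 0.79])
print(keep_all_max(b).numpy())
```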
If you want to keep only one occurrence of the maximum, you could use
argmax = tf.argmax(a)
a_max = tf.scatter_nd([[argmax]], [a[argmax]], tf.shape(a, out_type=tf.int64))
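The single-occurrence approach as a runnable sketch, again assuming TensorFlow 2.x in eager mode. keep_one_max is a hypothetical helper name. The shape is requested as int64 directly because tf.scatter_nd needs its indices and shape to share an integer dtype (tf.argmax returns int64 by default), and tf.to_int64 is no longer in the TF 2.x top-level namespace:

```python
import tensorflow as tf  # sketch assumes TensorFlow 2.x with eager execution


def keep_one_max(x):
    """Zero every element of 1-D tensor x except a single maximum."""
    # Index of one maximum; int64 by default
    argmax = tf.argmax(x)
    # Shape as int64 so it matches the dtype of the indices
    shape = tf.shape(x, out_type=tf.int64)
    # Write x[argmax] at position argmax into an all-zero tensor of that shape
    return tf.scatter_nd([[argmax]], [x[argmax]], shape)


a = tf.constant([0.3, 0.5, 0.79, 0.2, 0.11])
print(keep_one_max(a).numpy())

# With a tie, exactly one maximum is kept (which one is not guaranteed)
b = tf.constant([0.79, 0.5, 0.79])
print(keep_one_max(b).numpy())
```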
However, according to the documentation of tf.argmax:
Note that in case of ties the identity of the return value is not guaranteed
As I understand it, the maximum that is kept may be neither the first nor the last occurrence, and it may not even be the same one if run twice on the same array.
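If a deterministic choice matters, one workaround (my own sketch, not something the tf.argmax docs prescribe) is to avoid tf.argmax entirely and pick the first maximum with a running count. It assumes TensorFlow 2.x eager mode, and keep_first_max is a hypothetical helper name:

```python
import tensorflow as tf  # sketch assumes TensorFlow 2.x with eager execution


def keep_first_max(x):
    """Zero every element of 1-D tensor x except the first maximum."""
    # True wherever an element equals the maximum (all ties)
    cond = tf.equal(x, tf.reduce_max(x))
    # Running count of maxima seen so far; it is 1 from the first maximum
    # up to (but not including) the second one
    count = tf.cumsum(tf.cast(cond, tf.int32))
    # Keep only the position that is a maximum AND has running count 1
    first = tf.logical_and(cond, tf.equal(count, 1))
    return tf.where(first, x, tf.zeros_like(x))


b = tf.constant([0.79, 0.5, 0.79])
print(keep_first_max(b).numpy())
```

Passing reverse=True to tf.cumsum would keep the last maximum instead. Because it only uses element-wise ops and a cumulative sum, the selected position does not depend on how tf.argmax breaks ties.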