How can I take the 2nd max from a Tensorflow tensor?

Right now, my function uses argmax:

 p = tf.stop_gradient(tf.argmax(prev, 1))

I have tried using the following, but the dimensions are incompatible:

 p = tf.stop_gradient(tf.nn.top_k(prev, 2)[1])

 raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
 ValueError: Linear is expecting 2D arguments: [[None, 2, 1024], [None, 1024]]

My TF version might be 0.5, which is why top_k only takes 2 arguments.

user2827214 asked Apr 07 '16 22:04


1 Answer

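Per the documentation, tf.nn.top_k() returns two tensors, values and indices, each of shape [batch, k], holding the k largest entries along the last dimension of each row, largest first. Feeding the whole [batch, 2] indices tensor into code that expected argmax's [batch]-shaped result is what introduces the extra dimension in the error ([None, 2, 1024]). To get the index of the second-largest entry in every row, take the second column of indices, not the second row:

 values, indices = tf.nn.top_k(prev, 2)  # values and indices both have shape [batch, 2]
 p = tf.stop_gradient(indices[:, 1])     # 2nd-max index per row, shape [batch] like argmax

Writing indices[1] would instead select the top-2 indices of the second example in the batch. If your TF version does not support this slicing syntax, tf.squeeze(tf.slice(indices, [0, 1], [-1, 1]), [1]) should give the same [batch]-shaped result.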
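To see why the shapes clash, here is a small pure-Python sketch that mimics the documented shapes of tf.argmax(prev, 1) and tf.nn.top_k(prev, 2) on a batch of three rows. It does not use TensorFlow: lists of rows stand in for a 2-D tensor, and the helper names argmax_rows and top_k_rows are hypothetical, made up for this illustration.

```python
# Pure-Python stand-in for the shapes produced by tf.argmax(prev, 1) and
# tf.nn.top_k(prev, k). A 2-D "tensor" is modelled as a list of rows.
# Illustration only; this is not how TensorFlow implements these ops.


def argmax_rows(rows):
    """Index of the largest entry in each row: shape [batch], like tf.argmax(prev, 1)."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in rows]


def top_k_rows(rows, k):
    """Values and indices of the k largest entries per row, largest first.

    Both results have shape [batch, k], like tf.nn.top_k(prev, k).
    """
    values, indices = [], []
    for row in rows:
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        indices.append(order)
        values.append([row[j] for j in order])
    return values, indices


prev = [
    [0.1, 0.7, 0.2],  # max at index 1, 2nd max at index 2
    [0.9, 0.3, 0.5],  # max at index 0, 2nd max at index 2
    [0.4, 0.6, 0.8],  # max at index 2, 2nd max at index 1
]

values, indices = top_k_rows(prev, 2)
print(argmax_rows(prev))  # [1, 0, 2]                 shape [batch]
print(indices)            # [[1, 2], [0, 2], [2, 1]]  shape [batch, 2]

# Analogue of indices[:, 1]: second column, one entry per row.
second_max = [row_idx[1] for row_idx in indices]
print(second_max)         # [2, 2, 1]                 shape [batch], like argmax

# Analogue of indices[1]: the second *row* (top-2 of example 1 only).
print(indices[1])         # [0, 2]                    length k, not batch
```

Only the second-column version keeps one index per example, so it can stand in for argmax's output without changing downstream shapes.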
shekkizh answered Oct 10 '22 03:10