Right now, my function uses argmax:
p = tf.stop_gradient(tf.argmax(prev, 1))
I have tried using the following, but the dimensions are incompatible:
p = tf.stop_gradient(tf.nn.top_k(prev, 2)[1])
raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
ValueError: Linear is expecting 2D arguments: [[None, 2, 1024], [None, 1024]]
My TensorFlow version may be 0.5, which would explain why top_k takes only two arguments.
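Check the documentation for tf.nn.top_k(). The function returns two tensors, values and indices, each of shape [batch_size, k]. Indexing with indices[1] selects the second row (the second example in the batch), not the second-best index, so take a column instead. To get the second-highest index for every row, something like this should work:
values, indices = tf.nn.top_k(prev, 2)
p = tf.stop_gradient(indices[:, 1])  # column 1 of each row: shape [batch_size], like tf.argmax(prev, 1)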
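The shape mismatch comes from k: top_k returns indices of shape [batch_size, k], so anything built from them carries an extra dimension of size k (here 2), which is likely where the [None, 2, 1024] argument comes from. The sketch below emulates tf.nn.top_k in plain Python to show which elements different index expressions pick. The helper `top_k` and the sample `prev` scores are hypothetical illustrations, not TensorFlow's API.

```python
def top_k(rows, k):
    """Hypothetical pure-Python stand-in for tf.nn.top_k on a 2-D batch.

    Returns (values, indices), each a list of len(rows) lists of length k,
    sorted largest first. Python's sort is stable, so ties keep the lower index.
    """
    values, indices = [], []
    for row in rows:
        # Positions ordered by score, highest first; keep the first k.
        order = sorted(range(len(row)), key=lambda j: -row[j])[:k]
        indices.append(order)
        values.append([row[j] for j in order])
    return values, indices


# Hypothetical scores: a batch of 2 examples over a vocabulary of 3 symbols.
prev = [[0.1, 0.7, 0.2],
        [0.5, 0.3, 0.2]]

values, indices = top_k(prev, 2)   # indices has shape [batch, k] = [2, 2]

row_1 = indices[1]                 # both top-2 ids of the second example only
best = [r[0] for r in indices]     # column 0 of every row, i.e. the argmax
second = [r[1] for r in indices]   # column 1 of every row, like indices[:, 1]

print(indices)  # [[1, 2], [0, 1]]
print(row_1)    # [0, 1]
print(best)     # [1, 0]
print(second)   # [2, 1]
```

In the TensorFlow graph, selecting a single column of indices gives a tensor of shape [batch_size], the same shape tf.argmax(prev, 1) produces, so downstream layers see the 2-D input they expect.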