 

Tensorflow indicator matrix for top n values

Does anyone know how to extract the top n largest values per row of a rank 2 tensor?

For instance, if I wanted the top 2 values of a tensor of shape [2,4] with values:

[[40, 30, 20, 10], [10, 20, 30, 40]]

The desired condition matrix would look like: [[True, True, False, False],[False, False, True, True]]

Once I have the condition matrix, I can use tf.select (renamed tf.where in TensorFlow 1.0) to choose the actual values.

Thank you for your assistance!

Terry Rabinowitz asked Nov 25 '16 16:11


2 Answers

You can do it using the built-in tf.nn.top_k function:

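import tensorflow as tf  # TF 1.x API; in TF 2.x ops run eagerly and no Session is needed

a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]])
b = tf.nn.top_k(a, 2)  # top 2 values per row (largest first) and their column indices
sess = tf.Session()

print(sess.run(b))
TopKV2(values=array([[40, 30],
       [40, 30]], dtype=int32), indices=array([[0, 1],
       [3, 2]], dtype=int32))

print(sess.run(b).values)
[[40 30]
 [40 30]]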
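Because tf.nn.top_k already returns the column indices, the mask can also be built by marking those indices directly, which gives exactly k True entries per row even when values tie. The plain-Python sketch below illustrates the idea without TensorFlow; the names top_k_indices and topk_mask_from_indices are hypothetical helpers, and the tie-breaking (lower index first) is meant to mirror the documented behavior of tf.nn.top_k. In TensorFlow itself, a corresponding step would be tf.reduce_sum(tf.one_hot(b.indices, depth=4), axis=1) > 0, assuming rows of length 4.

```python
def top_k_indices(row, k):
    """Column indices of the k largest entries of `row`, largest first.

    sorted() stays stable with reverse=True, so equal values keep
    their original order (lower index first).
    """
    return sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]


def topk_mask_from_indices(matrix, k):
    """Boolean mask that is True exactly at each row's top-k positions."""
    mask = []
    for row in matrix:
        flags = [False] * len(row)
        for j in top_k_indices(row, k):
            flags[j] = True  # mark each column picked by top-k
        mask.append(flags)
    return mask


a = [[40, 30, 20, 10], [10, 20, 30, 40]]
print(topk_mask_from_indices(a, 2))
# [[True, True, False, False], [False, False, True, True]]
print(topk_mask_from_indices([[40, 30, 30, 10]], 2))  # tie at the 2nd value
# [[True, True, False, False]]
```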

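To get a boolean mask, take the k-th largest value of each row (the last column of b.values, since top_k returns the values sorted in descending order) and compare against it with tf.greater_equal. If a row has ties at its k-th value, more than k entries in that row will be True:

# per-row k-th value, shape [2, 1]; a single global tf.reduce_min(b.values)
# would only work if every row had the same k-th value
kth = b.values[:, -1:]
top2 = tf.greater_equal(a, kth)  # kth broadcasts across each row
print(sess.run(top2))
[[ True  True False False]
 [False False  True  True]]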
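For checking results on small inputs, here is a plain-Python sketch of the same threshold rule (the function name topk_mask_by_threshold is hypothetical). It computes the k-th largest value separately for each row; the last call shows the tie case, where three entries pass the >= test for k = 2.

```python
def topk_mask_by_threshold(matrix, k):
    """True where an entry is >= the k-th largest value of its own row."""
    mask = []
    for row in matrix:
        kth = sorted(row, reverse=True)[k - 1]  # per-row k-th largest value
        mask.append([x >= kth for x in row])
    return mask


a = [[40, 30, 20, 10], [10, 20, 30, 40]]
print(topk_mask_by_threshold(a, 2))
# [[True, True, False, False], [False, False, True, True]]

# Rows with different k-th values work because kth is computed per row.
print(topk_mask_by_threshold([[40, 30, 20, 10], [1, 2, 3, 4]], 2))
# [[True, True, False, False], [False, False, True, True]]

# A tie at the k-th value lets more than k entries through.
print(topk_mask_by_threshold([[40, 30, 30, 10]], 2))
# [[True, True, True, False]]
```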
sygi answered Nov 13 '22 22:11


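You can also use tf.contrib.framework.argsort (tf.contrib no longer exists in TF 2.x, where the same function is available as tf.argsort with the same direction argument):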

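import tensorflow as tf  # TF 1.x; tf.contrib was removed in TF 2.x

a = [[40, 30, 20, 10], [10, 20, 30, 40]]
idx = tf.contrib.framework.argsort(a, direction='DESCENDING')  # column indices of each row, largest value first
ranks = tf.contrib.framework.argsort(idx, direction='ASCENDING')  # inverse permutation: rank of each element (0 = largest)
b = ranks < 2  # True for the 2 highest-ranked elements of each row
# evaluates to [[ True  True False False] [False False  True  True]]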

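Moreover, you can replace 2 with an integer tensor of shape [num_rows, 1], e.g. [[1], [3]], so that each row uses its own n. Because of broadcasting, a plain 1-D tensor of length num_cols would instead apply a separate n to each column.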
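The argsort-of-argsort trick can be sketched in plain Python as well. The helper names argsort and topk_mask_by_rank are hypothetical, and ns holds one n per row, playing the role of a per-row n tensor. Python's sort is stable, so ties go to the lower column index; TensorFlow's argsort only promises that when stable=True. Either way, the ranks in each row are a permutation of 0..len-1, so each row gets exactly n True entries.

```python
def argsort(seq, descending=False):
    """Indices that would sort `seq`; stable, so ties keep index order."""
    return sorted(range(len(seq)), key=lambda i: seq[i], reverse=descending)


def topk_mask_by_rank(matrix, ns):
    """True where an entry's descending rank within its row is < that row's n."""
    mask = []
    for row, n in zip(matrix, ns):
        order = argsort(row, descending=True)  # column indices, largest value first
        ranks = argsort(order)  # inverse permutation: ranks[j] = position of column j in order
        mask.append([r < n for r in ranks])
    return mask


a = [[40, 30, 20, 10], [10, 20, 30, 40]]
print(topk_mask_by_rank(a, [2, 2]))
# [[True, True, False, False], [False, False, True, True]]
print(topk_mask_by_rank(a, [1, 3]))  # a different n for each row
# [[True, False, False, False], [False, True, True, True]]
```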

dontloo answered Nov 13 '22 22:11