
Return the top_k masked softmax of each row of a 2D tensor

For any 2D tensor like

[[2,5,4,7], [7,5,6,8]],

I want to apply softmax to the top k elements of each row and then construct a new tensor in which all the other elements are set to 0.

For each row, the top k (here k=2) elements are [[7,5],[8,7]], and their softmax is [[0.880797,0.11920291], [0.7310586,0.26894143]]. I then want to build a new tensor that places these values back at the positions the top k elements occupied in the original tensor, so the final result should be

[[0,0.11920291,0,0.880797], [0.26894143,0,0,0.7310586]].
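The expected values above can be double-checked with a minimal plain-Python sketch (not part of the original question). It assumes ties are broken by taking the earlier column first, and the helper names `softmax` and `top_k_masked_softmax` are hypothetical.

```python
import math

def softmax(xs):
    # Hypothetical helper: subtract the max for numerical stability
    # (mathematically the result is unchanged).
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_masked_softmax(row, k):
    # Hypothetical helper: column indices of the k largest entries.
    # sorted() is stable, so equal values keep their original (left-to-right) order.
    top_idx = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    probs = softmax([row[i] for i in top_idx])
    # Every position that is not in the top k stays 0.
    out = [0.0] * len(row)
    for i, p in zip(top_idx, probs):
        out[i] = p
    return out

data = [[2, 5, 4, 7], [7, 5, 6, 8]]
result = [top_k_masked_softmax(row, 2) for row in data]
print(result)
```

Each row of `result` sums to 1 and has exactly k non-zero entries, matching the target tensor in the question.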

Is it possible to implement this kind of masked softmax in TensorFlow? Many thanks in advance!

asked Nov 13 '18 12:11 by clement116


1 Answer

Here is how you can do that (this uses the TensorFlow 1.x graph-mode API, with placeholders and a session):

import tensorflow as tf

# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)

Output:

[[0.         0.11920291 0.         0.880797  ]
 [0.26894143 0.         0.         0.7310586 ]]
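For readers without TensorFlow 1.x at hand, here is a rough NumPy sketch of the same three steps (select the top k per row, softmax them, scatter them back into zeros). It is an illustration, not the answer's code: `np.argpartition` stands in for `tf.nn.top_k(..., sorted=False)` and `np.put_along_axis` stands in for `tf.scatter_nd`. The function name is hypothetical, and ties at the k-th largest value are resolved arbitrarily.

```python
import numpy as np

def top_k_masked_softmax_np(a, k):
    # Hypothetical helper mirroring the TensorFlow graph above.
    a = np.asarray(a, dtype=np.float64)
    # Column indices of the k largest entries per row; their order within
    # the k is arbitrary, like tf.nn.top_k with sorted=False.
    top_idx = np.argpartition(a, -k, axis=1)[:, -k:]
    top_vals = np.take_along_axis(a, top_idx, axis=1)
    # Row-wise softmax over the k selected values (max subtracted for stability).
    exps = np.exp(top_vals - top_vals.max(axis=1, keepdims=True))
    probs = exps / exps.sum(axis=1, keepdims=True)
    # Scatter the probabilities into a zero array of the original shape.
    out = np.zeros_like(a)
    np.put_along_axis(out, top_idx, probs, axis=1)
    return out

print(top_k_masked_softmax_np([[2, 5, 4, 7], [7, 5, 6, 8]], 2))
```

The index bookkeeping is simpler here than in the TensorFlow version because `put_along_axis` implicitly pairs each column index with its row, which is what the explicit `a_row_idx` / `tf.stack` step does for `tf.scatter_nd`.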

EDIT:

Actually, there is a function that more closely matches what you intend: tf.sparse.softmax. However, it requires a SparseTensor as input, and I'm not sure it would be any faster, since it has to work out which sparse values belong together in each softmax. Its advantage is that each row can have a different number of elements in the softmax, although in your case that does not seem to matter. Anyway, here is an implementation using it, in case you find it useful.

import tensorflow as tf

a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Flatten values
sparse_values = tf.reshape(a_top, [-1])
# Make sparse indices
shape = tf.cast(tf.shape(a), tf.int64)
a_row_idx = tf.tile(tf.range(shape[0])[:, tf.newaxis], (1, num_top))
sparse_idx = tf.stack([a_row_idx, tf.cast(a_top_idx, tf.int64)], axis=-1)
sparse_idx = tf.reshape(sparse_idx, [-1, 2])
# Make sparse tensor
a_top_sparse = tf.SparseTensor(sparse_idx, sparse_values, shape)
# Reorder sparse tensor
a_top_sparse = tf.sparse.reorder(a_top_sparse)
# Softmax
result_sparse = tf.sparse.softmax(a_top_sparse)
# Convert back to dense (or you can keep working with the sparse tensor)
result = tf.sparse.to_dense(result_sparse)
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)
    # Prints the same result as before
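To make the tf.sparse.softmax semantics described in the edit concrete, here is a pure-Python sketch of a row-wise softmax over only the stored entries of a 2-D sparse tensor in COO form (a list of (row, col) indices plus matching values). It is an illustration of the behaviour under that assumption, not TensorFlow's implementation, and `sparse_row_softmax` is a hypothetical helper. Row 1 is given three stored entries to show that rows need not hold the same number of elements.

```python
import math
from collections import defaultdict

def sparse_row_softmax(indices, values, dense_shape):
    """Hypothetical helper: softmax over the stored entries of each row.

    indices: list of (row, col) pairs; values: matching list of floats.
    Unstored positions take no part in the softmax and come out as 0.
    """
    # Group stored entries by row; rows may hold different numbers of entries.
    rows = defaultdict(list)
    for (r, c), v in zip(indices, values):
        rows[r].append((c, v))
    n_rows, n_cols = dense_shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for r, entries in rows.items():
        # Subtract the row max for numerical stability.
        m = max(v for _, v in entries)
        exps = [(c, math.exp(v - m)) for c, v in entries]
        total = sum(e for _, e in exps)
        for c, e in exps:
            dense[r][c] = e / total
    return dense

# Top-2 entries of row 0 from the example, and three entries in row 1.
indices = [(0, 3), (0, 1), (1, 3), (1, 0), (1, 2)]
values = [7.0, 5.0, 8.0, 7.0, 6.0]
print(sparse_row_softmax(indices, values, (2, 4)))
```

Row 0 reproduces the first row of the result above, while row 1 now spreads its probability over three positions instead of two.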
answered Nov 05 '22 08:11 by jdehesa