
Binary mask from tf.nn.top_k indices for 4-D tensor in TensorFlow?

Tags:

tensorflow

I have a 4-D tensor of shape (10, 32, 32, 128). I want to generate a binary mask that marks the top N elements along the last dimension (the axis tf.nn.top_k operates on).

import tensorflow as tf

N = 64
arr = tf.random_normal(shape=(10, 32, 32, 128))
values, indices = tf.nn.top_k(arr, k=N)  # the keyword argument is k, not N

My question is: how can I get a binary mask with the same shape as arr from the indices returned by tf.nn.top_k?
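To pin down what the mask should contain, here is a small reference sketch. It uses NumPy rather than TensorFlow (an assumption made only so the target is easy to check), and the helper name `topk_mask_reference` is hypothetical. Because tf.nn.top_k selects along the last axis, every length-128 slice of the mask should contain exactly N ones.

```python
import numpy as np


def topk_mask_reference(arr: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: 0/1 float32 mask of the k largest entries along the last axis."""
    # argpartition moves the k largest entries (in arbitrary order) into the last k slots.
    idx = np.argpartition(arr, -k, axis=-1)[..., -k:]
    mask = np.zeros_like(arr, dtype=np.float32)
    # Write 1.0 at each selected position along the last axis.
    np.put_along_axis(mask, idx, 1.0, axis=-1)
    return mask


rng = np.random.default_rng(0)
arr = rng.standard_normal((10, 32, 32, 128)).astype(np.float32)
mask = topk_mask_reference(arr, k=64)
print(mask.shape)  # (10, 32, 32, 128)
print(int(mask.sum(axis=-1).min()), int(mask.sum(axis=-1).max()))  # 64 64
```

Any TensorFlow solution can be checked against this by evaluating both on the same input.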

Asked Oct 30 '22 09:10 by Ishant Mrinal



1 Answer

In case anyone is looking for the answer, here it is:

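import tensorflow as tf  # TensorFlow 1.x API (tf.random_normal, tf.sparse_reorder, ...)

K = 64
arr = tf.random_normal(shape=(10, 32, 32, 128))
# Top-K along the last axis; the order of the K results does not matter for a mask.
values, indices = tf.nn.top_k(arr, k=K, sorted=False)

# Index grids for the leading axes (10, 32, 32) plus a K axis, each of shape (10, 32, 32, K).
temp_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(
       tf.shape(arr)[:(arr.get_shape().ndims - 1)]) + [K])], indexing='ij')
# Replace the K-axis grid with the top_k indices: shape (10, 32, 32, K, 4).
temp_indices = tf.stack(temp_indices[:-1] + [indices], axis=-1)
# One full 4-D coordinate per selected element.
full_indices = tf.reshape(temp_indices, [-1, arr.get_shape().ndims])
values = tf.reshape(values, [-1])

# Put a 1 at every selected coordinate. sparse_reorder sorts the indices into the
# lexicographic order that sparse_tensor_to_dense checks for by default.
mask_st = tf.SparseTensor(indices=tf.cast(
      full_indices, dtype=tf.int64), values=tf.ones_like(values), dense_shape=arr.shape)
# float32 mask: 1.0 at the top-K positions of each last-axis slice, 0.0 elsewhere.
mask = tf.sparse_tensor_to_dense(tf.sparse_reorder(mask_st))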
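The index bookkeeping above is easier to follow with concrete arrays. The sketch below mirrors the same steps in NumPy (NumPy is an assumption for illustration, and `topk_mask_via_coordinates` is a hypothetical name): np.argpartition stands in for tf.nn.top_k(..., sorted=False), np.meshgrid(..., indexing="ij") plays the role of tf.meshgrid, and a direct scatter into a zero array replaces the SparseTensor-to-dense conversion.

```python
import numpy as np


def topk_mask_via_coordinates(arr: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical NumPy mirror of the answer: build full coordinates, then scatter 1s."""
    # Stand-in for tf.nn.top_k(..., sorted=False): k largest along the last axis, unordered.
    indices = np.argpartition(arr, -k, axis=-1)[..., -k:]
    # One grid per leading axis plus a placeholder grid for the k axis (like tf.meshgrid).
    grids = np.meshgrid(*[np.arange(d) for d in arr.shape[:-1]], np.arange(k), indexing="ij")
    # Swap the placeholder grid for the real last-axis indices -> shape (..., k, ndim).
    # list() is needed because newer NumPy versions return a tuple from meshgrid.
    coords = np.stack(list(grids[:-1]) + [indices], axis=-1)
    full_indices = coords.reshape(-1, arr.ndim)
    # Dense equivalent of SparseTensor(full_indices, ones) followed by to-dense.
    mask = np.zeros(arr.shape, dtype=np.float32)
    mask[tuple(full_indices.T)] = 1.0
    return mask


rng = np.random.default_rng(0)
arr = rng.standard_normal((10, 32, 32, 128)).astype(np.float32)
mask = topk_mask_via_coordinates(arr, k=64)
print(mask.shape)  # (10, 32, 32, 128)
```

A shorter route, not the one used in the answer, should be `tf.reduce_sum(tf.one_hot(indices, depth=128), axis=-2)`: each of the K distinct indices contributes one 1 in its slice. The trade-off is memory, since it materializes a (10, 32, 32, 64, 128) intermediate of roughly 84 million floats, whereas the coordinate approach only builds (10 * 32 * 32 * 64) x 4 indices.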
Answered Nov 15 '22 11:11 by Ishant Mrinal
