
Set k-largest elements of a tensor to zero in TensorFlow

Tags:

tensorflow

I want to find the k largest elements in each row of a tensor h and set those elements to zero.

I was able to select the index of the largest value in each row using the top_k function:

```python
top_k = tf.nn.top_k(h, 1)
```
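It helps to look at what top_k actually returns: for each row, the k largest values and their column indices, but no row indices. A minimal sketch, assuming TensorFlow 2.x with eager execution; the tensor h here is a made-up example, not data from the question:

```python
import tensorflow as tf

# Hypothetical example input standing in for h
h = tf.constant([[1., 9., 3.],
                 [7., 2., 8.]])

# k=1: the largest value in each row and its column index within that row
values, indices = tf.nn.top_k(h, k=1)
print(values.numpy())   # [[9.], [8.]]
print(indices.numpy())  # [[1], [2]]  (columns only; the row is implied by position)
```

Because the indices are per-row column positions, they have to be paired with row numbers before they can address individual elements of h.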

But I could not figure out how to use the indices returned by top_k to update the tensor.

How can I do that? Thanks in advance...

Turkdogan Tasdelen asked Jun 02 '16 13:06


1 Answer

This is a bit tricky; there may be a better solution. tf.scatter_update() doesn't work here because it only updates whole slices along the first dimension (it cannot target, for instance, just the element in the first row and second column), and it only operates on a tf.Variable.
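By contrast, a scatter op that takes full [row, col] index pairs can address individual elements. A minimal sketch, assuming TensorFlow 2.x, where tf.tensor_scatter_nd_update works on ordinary (non-Variable) tensors; note this is a different op from the tf.scatter_update() discussed above, and the indices here are hand-picked for illustration:

```python
import tensorflow as tf

x = tf.constant([[6., 2., 0.],
                 [0., 4., 5.]])

# Each index is a full [row, col] pair, so a single element can be targeted
indices = tf.constant([[0, 1], [1, 2]])  # row 0 / col 1 and row 1 / col 2
updates = tf.zeros([2])                   # one new value per index pair
y = tf.tensor_scatter_nd_update(x, indices, updates)
print(y.numpy())  # [[6., 0., 0.], [0., 4., 0.]]
```

The hard part is producing those [row, col] pairs from the per-row column indices that top_k returns.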

Use the values and indices from tf.nn.top_k() to build a sparse tensor that holds the top-k values at their positions, convert it to a dense tensor, and subtract it from the initial tensor x:

```python
import tensorflow as tf

x = tf.constant([[6., 2., 0.], [0., 4., 5.]])  # of type tf.float32

k = 2
# indices ~ [[0, 1], [1, 2]], values ~ [[6., 2.], [4., 5.]]
# (with sorted=False the order within each row is not guaranteed)
values, indices = tf.nn.top_k(x, k, sorted=False)

# We need full [row, col] indices like [[0, 0], [0, 1], [1, 1], [1, 2]]
my_range = tf.expand_dims(tf.range(0, tf.shape(indices)[0]), 1)  # [[0], [1]]
my_range_repeated = tf.tile(my_range, [1, k])  # [[0, 0], [1, 1]]

# Change both shapes to [N, k, 1], concatenate into [N, k, 2], then flatten to [N*k, 2]
full_indices = tf.concat([tf.expand_dims(my_range_repeated, 2), tf.expand_dims(indices, 2)], axis=2)
full_indices = tf.reshape(full_indices, [-1, 2])

# Sparse tensor holding the top-k values at their positions; reorder sorts the
# indices (needed because of sorted=False) before densifying with 0 elsewhere
to_subtract = tf.sparse.to_dense(tf.sparse.reorder(tf.sparse.SparseTensor(
    indices=tf.cast(full_indices, tf.int64),
    values=tf.reshape(values, [-1]),
    dense_shape=tf.shape(x, out_type=tf.int64))))

res = x - to_subtract  # all 0.: each row's top 2 are removed and its third entry is already 0
```
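An alternative that avoids assembling [row, col] pairs by hand (not part of the original answer; a sketch assuming TensorFlow 2.x, and zero_top_k is a hypothetical helper name): turn the top_k column indices into a per-row 0/1 mask with tf.one_hot, then zero the masked entries with tf.where.

```python
import tensorflow as tf

def zero_top_k(x, k):
    """Hypothetical helper: return x with the k largest entries of each row set to 0."""
    # Column indices of the k largest entries per row, shape [..., k]
    _, indices = tf.nn.top_k(x, k)
    # one_hot gives shape [..., k, M]; summing over the k axis yields a 0/1 mask [..., M]
    mask = tf.reduce_sum(tf.one_hot(indices, depth=tf.shape(x)[-1], dtype=x.dtype), axis=-2)
    # Zero the masked positions and keep every other entry unchanged
    return tf.where(mask > 0, tf.zeros_like(x), x)

x = tf.constant([[6., 2., 0.], [0., 4., 5.]])
print(zero_top_k(x, 1).numpy())  # [[0., 2., 0.], [0., 4., 0.]]
```

Since top_k returns exactly k indices per row (on ties the lower index comes first), exactly k entries per row are zeroed. Using tf.where instead of x * (1 - mask) also avoids turning a masked inf into nan.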
Olivier Moindrot answered Oct 29 '22 21:10