I want to find the k largest elements of each row of h and set those maximum elements to zero.
I am able to select the index of the largest value in each row using the top_k function:
top_k = tf.nn.top_k(h, 1)
But I cannot use the indices returned by top_k to update the tensor.
How can I do that? Thanks in advance.
This is a bit tricky, and there may be a better solution. tf.scatter_update() doesn't work here because it can only modify slices of a tensor along the first dimension (it cannot update, for instance, only the element in the first row and second column).
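In newer TensorFlow versions, tf.tensor_scatter_nd_update can write individual elements given a list of [row, col] index pairs, so the top-k positions can be set to zero directly instead of subtracting. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution and a 2-D input; the helper name zero_top_k is hypothetical.

```python
import tensorflow as tf


def zero_top_k(x, k):
    """Hypothetical helper: set the k largest entries of each row of a 2-D tensor to zero."""
    # values and indices both have shape [N, k]; indices are column positions
    values, indices = tf.nn.top_k(x, k=k)
    n = tf.shape(x)[0]
    # Row numbers repeated k times: shape [N, k]
    rows = tf.tile(tf.expand_dims(tf.range(n), 1), [1, k])
    # Pair each row number with its column index, then flatten to [N * k, 2]
    full_indices = tf.reshape(tf.stack([rows, indices], axis=-1), [-1, 2])
    # One zero per selected element, same dtype as x
    zeros = tf.zeros_like(tf.reshape(values, [-1]))
    # Unlike tf.scatter_update, this targets individual elements, not whole rows
    return tf.tensor_scatter_nd_update(x, full_indices, zeros)


x = tf.constant([[6., 2., 0.], [0., 4., 5.]])
print(zero_top_k(x, 1).numpy())
```

Because the selected elements are overwritten rather than subtracted, the result is exactly zero at those positions regardless of their values.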
You have to get the values and indices from tf.nn.top_k(), use them to build a tensor that holds the top-k values at their original positions and zeros elsewhere, and subtract it from the initial tensor x:
import tensorflow as tf

x = tf.constant([[6., 2., 0.], [0., 4., 5.]])  # of type tf.float32
k = 2
values, indices = tf.nn.top_k(x, k, sorted=False)  # e.g. indices [[0, 1], [1, 2]], values [[6., 2.], [4., 5.]]
# We need full [row, col] indices like [[0, 0], [0, 1], [1, 1], [1, 2]]
my_range = tf.expand_dims(tf.range(0, tf.shape(indices)[0]), 1)  # will be [[0], [1]]
my_range_repeated = tf.tile(my_range, [1, k])  # will be [[0, 0], [1, 1]]
# change shapes to [N, k, 1] and [N, k, 1], to concatenate into [N, k, 2]
full_indices = tf.concat([tf.expand_dims(my_range_repeated, 2), tf.expand_dims(indices, 2)], axis=2)
full_indices = tf.reshape(full_indices, [-1, 2])
# Scatter the top-k values into a zero tensor shaped like x. tf.scatter_nd does not
# require sorted indices (sparse_to_dense does by default, and sorted=False gives no order guarantee)
to_subtract = tf.scatter_nd(full_indices, tf.reshape(values, [-1]), tf.shape(x))
res = x - to_subtract  # res should be all 0.
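If building [row, col] index pairs feels cumbersome, a mask can be built straight from the indices returned by tf.nn.top_k with tf.one_hot. This is a swapped-in alternative technique, not the approach above; it is a minimal sketch assuming TensorFlow 2.x and a 2-D input, and zero_top_k_mask is a hypothetical name.

```python
import tensorflow as tf


def zero_top_k_mask(x, k):
    """Hypothetical helper: zero the k largest entries of each row via a one-hot mask."""
    _, indices = tf.nn.top_k(x, k=k)  # shape [N, k]
    num_cols = tf.shape(x)[1]
    # one_hot yields shape [N, k, num_cols]; summing over the k axis marks the
    # selected columns of each row with 1 (top_k never repeats an index in a row)
    mask = tf.reduce_sum(tf.one_hot(indices, depth=num_cols, dtype=x.dtype), axis=1)
    # Write 0 where the mask is set, keep the original value elsewhere
    return tf.where(mask > 0, tf.zeros_like(x), x)


x = tf.constant([[6., 2., 0.], [0., 4., 5.]])
res = zero_top_k_mask(x, 2)
print(res.numpy())
```

Subtracting the values works for finite inputs, but an infinite maximum would give inf - inf = nan; selecting with tf.where avoids that.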