Say I have a tensor of size BxWxHxD. I want to process the tensor such that I have a new BxWxHxD tensor where only the maximum element in each WxH slice is kept, and all other values are zero.
In other words, I think the best way to achieve this is to take a 2D argmax across the WxH slices, resulting in BxD index tensors for the rows and columns, which can then be converted to a one-hot BxWxHxD tensor to be used as a mask. How do I make this work?
You can use the following function as a starting point. It calculates the indices of the maximum element for each batch item and each channel. The resulting tensor has the shape (batch size, 2, number of channels), where the second axis holds the row index followed by the column index.
import tensorflow as tf

def argmax_2d(tensor):
    # input format: BxHxWxD
    assert rank(tensor) == 4
    # flatten the Tensor along the height and width axes: Bx(H*W)xD
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))
    # argmax of the flat tensor, shape BxD
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)
    # convert flat indexes into 2D coordinates (tf.shape(tensor)[2] is the width)
    argmax_x = argmax // tf.shape(tensor)[2]  # row index (along H)
    argmax_y = argmax % tf.shape(tensor)[2]   # column index (along W)
    # stack and return 2D coordinates, shape Bx2xD
    return tf.stack((argmax_x, argmax_y), axis=1)

def rank(tensor):
    # return the rank of a Tensor
    return len(tensor.get_shape())
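To get from the indices to the masked tensor the question asks for, one possible approach is to skip the 2D coordinates and build the one-hot mask directly from the flat argmax. The following is only a sketch under the same BxHxWxD layout assumption; the function name keep_spatial_max is made up for illustration and is not part of the original answer.

```python
import tensorflow as tf

def keep_spatial_max(tensor):
    # hypothetical helper, input format assumed to be BxHxWxD
    shape = tf.shape(tensor)
    # flatten the height and width axes: Bx(H*W)xD
    flat = tf.reshape(tensor, (shape[0], -1, shape[3]))
    # flat index of the maximum for each batch item and channel: BxD
    idx = tf.argmax(flat, axis=1)
    # one-hot along the flattened spatial axis: Bx(H*W)xD
    mask = tf.one_hot(idx, depth=shape[1] * shape[2], axis=1, dtype=tensor.dtype)
    # restore BxHxWxD and zero out everything except the maximum
    return tensor * tf.reshape(mask, shape)
```

Because the mask comes from a single argmax index, exactly one element per slice is kept, even if several elements tie for the maximum. If keeping all tied maxima is acceptable, comparing the tensor against tf.reduce_max(tensor, axis=[1, 2], keepdims=True) would be a shorter alternative.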