
Tensorflow multi-dimension argmax

Tags: tensorflow

Say I have a tensor of size BxWxHxD. I want to process the tensor such that I have a new BxWxHxD tensor where only the maximum element in each WxH slice is kept, and all other values are zero.

In other words, I think the best way to achieve this is to take a 2D argmax across each WxH slice, which gives BxD index tensors for the rows and columns. These can then be converted into a one-hot BxWxHxD tensor and used as a mask. How do I make this work?

asked Apr 03 '16 17:04 by Rasmus


1 Answer

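You can use the following function as a starting point. It computes the 2D index of the maximum element for each batch entry and each channel. The result has shape (batch size, 2, number of channels), with the row index and the column index stacked along the second axis. Note that it expects the input as BxHxWxD, not BxWxHxD as in the question.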

import tensorflow as tf

def argmax_2d(tensor):

  # input format: BxHxWxD
  assert rank(tensor) == 4

  # flatten the Tensor along the height and width axes: (B, H*W, D)
  flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))

  # argmax of the flat tensor over the flattened axis: (B, D)
  argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)

  # convert flat indexes into 2D coordinates: dividing by the width gives
  # the row (height index), the remainder gives the column (width index)
  argmax_row = argmax // tf.shape(tensor)[2]
  argmax_col = argmax % tf.shape(tensor)[2]

  # stack and return 2D coordinates: (B, 2, D)
  return tf.stack((argmax_row, argmax_col), axis=1)

def rank(tensor):

  # return the rank of a Tensor
  return len(tensor.get_shape())

answered Nov 17 '22 12:11 by ondrejba
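
The argmax_2d function in the answer above returns coordinates only. To get the masked tensor the question asks for, one option is to skip the 2D coordinates and build the one-hot mask directly on the flattened height/width axis. This is a minimal sketch, not part of the original answer; it assumes TensorFlow 2.x and a (batch, height, width, channels) layout, and max_mask_2d is a hypothetical name, not a TensorFlow function.

```python
import tensorflow as tf


def max_mask_2d(tensor):
    """Keep only the maximum of each HxW slice and zero everything else.

    Assumes a rank-4 input laid out as (batch, height, width, channels).
    max_mask_2d is an illustrative helper, not a TensorFlow API.
    """
    shape = tf.shape(tensor)
    batch, channels = shape[0], shape[3]

    # Flatten height and width into a single axis: (B, H*W, D).
    flat = tf.reshape(tensor, (batch, -1, channels))

    # Flat index of the maximum for every (batch, channel) pair: (B, D).
    flat_argmax = tf.argmax(flat, axis=1)

    # One-hot along the flattened spatial axis: (B, H*W, D).
    mask = tf.one_hot(
        flat_argmax, depth=tf.shape(flat)[1], axis=1, dtype=tensor.dtype
    )

    # Restore the original layout and apply the mask.
    return tensor * tf.reshape(mask, shape)
```

Because tf.argmax returns a single index per slice, exactly one element per (batch, channel) slice is kept even when the maximum value occurs more than once. Calling max_mask_2d(x) on a tensor x of shape (B, H, W, D) returns a tensor of the same shape.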