 

In Tensorflow, how to unravel the flattened indices obtained by tf.nn.max_pool_with_argmax?

I have run into a problem: after calling tf.nn.max_pool_with_argmax, I obtain the indices, i.e. argmax: "A Tensor of type Targmax. 4-D. The flattened indices of the max values chosen for each output."

How can I unravel the flattened indices back into a list of coordinates in TensorFlow?

Thank you very much.
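The flattening convention is what makes this invertible. According to the TensorFlow documentation for tf.nn.max_pool_with_argmax, a maximum at position [b, y, x, c] of the NHWC input becomes the flat index (y * width + x) * channels + c when the batch is not included (newer versions also offer include_batch_in_index=True, giving ((b * height + y) * width + x) * channels + c). Unravelling is therefore just repeated integer division. Below is a minimal plain-Python sketch of the round trip; flatten_index and unravel_index are hypothetical helpers for illustration, not TensorFlow functions.

```python
def flatten_index(y, x, c, width, channels):
    """Hypothetical helper: flatten (y, x, c) the way max_pool_with_argmax
    does when the batch index is excluded: (y * width + x) * channels + c."""
    return (y * width + x) * channels + c


def unravel_index(flat, width, channels):
    """Hypothetical helper: invert flatten_index with divmod."""
    yx, c = divmod(flat, channels)  # strip off the channel
    y, x = divmod(yx, width)        # split the row-major (y, x) pair
    return y, x, c


# Round-trip check on an assumed 4x5 input with 3 channels.
height, width, channels = 4, 5, 3
for y in range(height):
    for x in range(width):
        for c in range(channels):
            flat = flatten_index(y, x, c, width, channels)
            assert unravel_index(flat, width, channels) == (y, x, c)

print(unravel_index(40, width, channels))  # → (2, 3, 1)
```

Note that height is not needed to unravel a single example's index; it only matters if the batch is folded into the index as well.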

asked Jun 07 '16 at 14:06 by karl_TUM



1 Answer

I had the same problem today and I ended up with this solution:

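import tensorflow as tf

def unravel_argmax(argmax, shape):
    # shape: pooling input shape [batch, height, width, channels].
    # Each index encodes (y * width + x) * channels + c (batch excluded).
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))  # row y
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])  # column x
    return tf.stack(output_list)  # tf.pack was renamed tf.stack in TF 1.0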
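To check the arithmetic without TensorFlow, here is a minimal NumPy sketch. naive_max_pool_with_argmax is a hypothetical, unoptimized stand-in for tf.nn.max_pool_with_argmax with a 2x2 window, stride 2 and VALID padding, assuming it follows the documented flattening convention with the batch excluded; unravel applies the same two formulas as unravel_argmax to a NumPy array. Gathering the input at the recovered coordinates should reproduce the pooled values.

```python
import numpy as np


def naive_max_pool_with_argmax(x, k=2):
    """Hypothetical reference: k x k, stride-k, VALID max pool on NHWC input.
    Returns pooled values and flat argmax indices (y * width + x) * channels + c."""
    n, h, w, ch = x.shape
    oh, ow = h // k, w // k
    pooled = np.empty((n, oh, ow, ch), dtype=x.dtype)
    argmax = np.empty((n, oh, ow, ch), dtype=np.int64)
    for b in range(n):
        for i in range(oh):
            for j in range(ow):
                for c in range(ch):
                    window = x[b, i * k:(i + 1) * k, j * k:(j + 1) * k, c]
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    y, xx = i * k + dy, j * k + dx
                    pooled[b, i, j, c] = x[b, y, xx, c]
                    argmax[b, i, j, c] = (y * w + xx) * ch + c
    return pooled, argmax


def unravel(argmax, shape):
    """Same arithmetic as unravel_argmax, on NumPy arrays.
    shape is the pooling input shape [batch, height, width, channels]."""
    rows = argmax // (shape[2] * shape[3])
    cols = argmax % (shape[2] * shape[3]) // shape[3]
    return np.stack([rows, cols])


rng = np.random.default_rng(0)
x = rng.random((2, 4, 6, 3))  # random floats, so ties are practically impossible
pooled, argmax = naive_max_pool_with_argmax(x)
rows, cols = unravel(argmax, x.shape)

# Batch and channel come from the output position, since the flat index
# does not encode the batch and pooling keeps channels separate.
b = np.arange(x.shape[0])[:, None, None, None]
c = np.arange(x.shape[3])[None, None, None, :]
assert np.array_equal(x[b, rows, cols, c], pooled)
```

With include_batch_in_index=True the batch term would have to be divided out first, using height * width * channels as the divisor.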

I have a usage example in an IPython notebook, where I use it to forward the pooling argmax positions to my unpooling method.
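One common way to use the argmax positions for unpooling, sketched here in NumPy and not necessarily how the notebook mentioned above does it, is to scatter each pooled value back to its flat position in a zero buffer shaped like the pooling input. The function name unpool is hypothetical, and the sketch assumes the flat indices exclude the batch, so each example is handled in its own height * width * channels buffer.

```python
import numpy as np


def unpool(pooled, argmax, input_shape):
    """Hypothetical helper: scatter pooled values to their argmax positions.

    argmax holds per-example flat indices (y * width + x) * channels + c,
    so each example is filled into a flat buffer of size
    height * width * channels and then reshaped to NHWC.
    """
    n, h, w, ch = input_shape
    out = np.zeros((n, h * w * ch), dtype=pooled.dtype)
    flat_vals = pooled.reshape(n, -1)
    flat_idx = argmax.reshape(n, -1)
    batch = np.arange(n)[:, None]  # broadcast batch index over positions
    out[batch, flat_idx] = flat_vals
    return out.reshape(n, h, w, ch)


# Tiny hand-made example: one image, 2x2 input, 1 channel, pooled to 1x1.
x = np.array([[[[1.0], [5.0]],
               [[3.0], [2.0]]]])  # shape (1, 2, 2, 1)
pooled = np.array([[[[5.0]]]])    # max of the single 2x2 window
argmax = np.array([[[[1]]]])      # (y=0 * 2 + x=1) * 1 + c=0 = 1
restored = unpool(pooled, argmax, x.shape)
print(restored[0, :, :, 0])       # 5.0 at row 0, column 1; zeros elsewhere
```

With overlapping windows (stride smaller than the window) several outputs can share an argmax; this fancy-index assignment then keeps only one of them, whereas a summing unpool would need np.add.at.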

answered Nov 15 '22 at 10:11 by Fabian
