 

Why does dim=1 return row indices in torch.argmax?

I am working with PyTorch's argmax function, which is defined as:

torch.argmax(input, dim=None, keepdim=False)

Consider this example:

import torch

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))  # argmax along dim=1

Here, when I use dim=1, the function does not search the column vectors as I expected; it searches the row vectors instead, as shown below.

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])

My assumption was that dim=0 represents rows and dim=1 represents columns.

Programmer asked Apr 15 '19 14:04


People also ask

What does torch argmax return?

With no dim argument, torch.argmax returns the index of the maximum value of all elements in the input tensor, counting positions as if the tensor were flattened. This is the second value returned by torch.
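A small sketch of the no-dim case, using the values printed in the question (hard-coded instead of torch.randn so the result is reproducible). It also shows, as a related check, that torch.max(input, dim) returns a (values, indices) pair whose indices match torch.argmax along the same dim:

```python
import torch

# Values copied from the question's printed tensor.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# With dim=None (the default), argmax treats the tensor as flattened
# and returns a single index into that flattened view.
flat_idx = torch.argmax(a)
print(flat_idx)  # tensor(15)

# Convert the flat index back to (row, column) coordinates.
row, col = divmod(flat_idx.item(), a.shape[1])
print(row, col)  # 3 3

# torch.max with a dim argument returns (values, indices); the indices
# agree with torch.argmax along the same dim.
values, indices = torch.max(a, dim=1)
print(torch.equal(indices, torch.argmax(a, dim=1)))  # True
```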

How does torch argmax work for 4 dimensions?

Argmax gives the index corresponding to the highest value along a given dimension, so the number of dimensions is not an issue. When you apply argmax along a given dimension, PyTorch by default collapses that dimension, since its values are replaced by a single index. Passing keepdim=True keeps that dimension with size 1 instead.
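To illustrate the shape behaviour, here is a sketch on a hypothetical 4-D tensor (think batch, channels, height, width; the shape is an arbitrary choice). Because torch.arange makes values increase along every dimension, the argmax along dim=1 is always the last channel index, 2:

```python
import torch

# Hypothetical 4-D tensor of shape (2, 3, 4, 5).
x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)

# Reducing along dim=1 removes that dimension from the result.
idx = torch.argmax(x, dim=1)
print(idx.shape)  # torch.Size([2, 4, 5])

# keepdim=True keeps the reduced dimension, with size 1.
idx_keep = torch.argmax(x, dim=1, keepdim=True)
print(idx_keep.shape)  # torch.Size([2, 1, 4, 5])

# Values grow along dim 1, so every position picks channel 2.
print(bool((idx == 2).all()))  # True
```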

How does torch argmax work?

The torch.argmax() method accepts a tensor and returns the indices of the maximum values of the input tensor along a specified dimension/axis. If the input tensor has multiple maximal values, the function returns the index of the first maximal element.
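A quick check of the tie-breaking rule, assuming a recent PyTorch version (the "first maximal value" behaviour is what the current documentation states; very old releases may not guarantee it):

```python
import torch

# Two equal maxima (3.0) at positions 1 and 2: the first one wins.
t = torch.tensor([1.0, 3.0, 3.0, 2.0])
first_idx = torch.argmax(t)
print(first_idx)  # tensor(1)

# The same rule applies per row when a dim is given.
m = torch.tensor([[5.0, 5.0, 1.0],
                  [0.0, 2.0, 2.0]])
row_idx = torch.argmax(m, dim=1)
print(row_idx)  # tensor([0, 1])
```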

How do you find the index of the maximum value in a torch tensor?

To find the indices of the maximum values of the elements in an input tensor, apply the torch.argmax() function. It returns only the indices, not the element values. If the input tensor has multiple maximal values, the function returns the index of the first maximal element.
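Since argmax returns only indices, a common follow-up is to look up the values themselves. A minimal sketch of two ways to do that with the question's tensor (hard-coded for reproducibility): advanced indexing, and torch.gather along the same dim:

```python
import torch

# Values copied from the question's printed tensor.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

idx = torch.argmax(a, dim=1)  # column index of the max in each row

# Option 1: advanced indexing, pairing each row number with its column index.
vals = a[torch.arange(a.shape[0]), idx]

# Option 2: gather along dim=1; the index tensor must have the same number
# of dimensions as `a`, hence unsqueeze(1) and then squeeze(1).
vals_gather = a.gather(1, idx.unsqueeze(1)).squeeze(1)

print(vals)  # tensor([0.8073, 0.6575, 1.7970, 1.8410])
print(torch.equal(vals, vals_gather))  # True
```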


1 Answer

It's time to understand correctly how the axis or dim argument works in PyTorch:

[Figure: tensor dimension]


The following example should make sense once you comprehend the above picture:

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
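The key point behind the diagram: dim names the dimension that argmax travels along and reduces away, not the one that is kept. With dim=1, argmax moves across the columns within each row, so the result holds one column index per row. A sketch comparing both choices on the question's values (hard-coded), plus a non-square random tensor where only the result shapes matter:

```python
import torch

# Values copied from the question's printed tensor.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=1: for each row, walk across the columns and report the
# column index of the maximum -> one entry per row.
by_row = torch.argmax(a, dim=1)
print(by_row)  # tensor([1, 1, 0, 3])

# dim=0: for each column, walk down the rows and report the
# row index of the maximum -> one entry per column.
by_col = torch.argmax(a, dim=0)
print(by_col)  # tensor([2, 0, 2, 3])

# A non-square tensor makes the reduced dimension obvious from the shape.
b = torch.randn(2, 5)
print(torch.argmax(b, dim=1).shape)  # torch.Size([2]): one index per row
print(torch.argmax(b, dim=0).shape)  # torch.Size([5]): one index per column
```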

Note: dim (short for 'dimension') is the torch equivalent of 'axis' in NumPy.
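To see the equivalence in practice, a minimal sketch (assuming NumPy is installed alongside PyTorch) that runs NumPy's argmax with axis=1 on the same values and compares it with torch.argmax using dim=1:

```python
import numpy as np
import torch

# Values copied from the question's printed tensor.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# NumPy's axis=1 and PyTorch's dim=1 select the same dimension.
np_idx = np.argmax(a.numpy(), axis=1)
torch_idx = torch.argmax(a, dim=1)

print(np_idx)                                     # [1 1 0 3]
print(np.array_equal(np_idx, torch_idx.numpy()))  # True
```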

kmario23 answered Sep 16 '22 21:09