I have a tensor x with x.shape = [3, 2, 2]:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
I need to take .max() over the 2nd and 3rd dimensions. I expect something like [-0.2632, -0.1453, -0.0274] as output. I tried x.max(dim=(1,2)), but this causes an error.
torch.max(input) returns the maximum value of all elements in the input tensor. torch.max(input, dim) instead returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim.
On max() and gradients: the max operation is not differentiable everywhere (for example, where several elements tie for the maximum).
To find the indices of the maximum value of the elements in an input tensor, you can apply the torch.argmax() function. It returns only the indices, not the element values. If the input tensor has multiple maximal values, the function returns the index of the first maximal element.
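A small sketch of the three call forms described above, using the tensor from the question (the variable names are only illustrative). Note that torch.argmax without dim indexes into the flattened tensor.

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

# No dim: a single 0-d tensor holding the global maximum.
global_max = torch.max(x)

# An int dim: a (values, indices) namedtuple, reduced over that one dim only.
values, indices = torch.max(x, dim=2)

# argmax without dim: an index into the flattened tensor (x.flatten()).
flat_idx = torch.argmax(x)

print(global_max)    # tensor(-0.0274)
print(values.shape)  # torch.Size([3, 2])
print(flat_idx)      # tensor(11)
```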
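Now you can do this: PR#43092 was merged (Aug 28, 2020) and the feature is available in the nightly release (torch.amax and torch.amin have since shipped in stable releases, starting with PyTorch 1.7).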
Simply use torch.amax():
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(torch.amax(x, dim=(1, 2)))
# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
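A short sketch of two other torch.amax options that often come up alongside this, assuming PyTorch 1.7 or later: negative dims, and keepdim=True to keep the reduced axes as size 1 so the result broadcasts back against x (here, hypothetically, to subtract each slice's maximum).

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

# For a 3-D tensor, (-2, -1) names the same axes as (1, 2).
m = torch.amax(x, dim=(-2, -1))

# keepdim=True leaves shape (3, 1, 1), which broadcasts against (3, 2, 2).
m_keep = torch.amax(x, dim=(1, 2), keepdim=True)
shifted = x - m_keep  # each slice's maximum becomes exactly 0

print(m_keep.shape)  # torch.Size([3, 1, 1])
```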
Original Answer
As of today (April 11, 2020), there is no way to do .min() or .max() over multiple dimensions in PyTorch. There is an open issue about it that you can follow to see if it ever gets implemented. A workaround in your case would be:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(x.view(x.size(0), -1).max(dim=-1))
# output:
# >>> torch.return_types.max(
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
So, if you need only the values: x.view(x.size(0), -1).max(dim=-1).values.
If x is not a contiguous tensor, then .view() will fail. In this case, you should use .reshape() instead.
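A quick way to see the .view() failure for yourself: transpose() returns a non-contiguous view, so flattening its last two axes with .view() raises a RuntimeError, while .reshape() (which copies when it has to) works. The transpose is just one illustrative way to get a non-contiguous tensor.

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

# transpose() shares storage with x but has non-contiguous strides.
xt = x.transpose(1, 2)
print(xt.is_contiguous())  # False

view_failed = False
try:
    xt.view(xt.size(0), -1)
except RuntimeError:
    view_failed = True  # strides are incompatible with the requested view

# reshape() falls back to a copy, so the workaround still works.
vals = xt.reshape(xt.size(0), -1).max(dim=-1).values
print(vals)  # tensor([-0.2632, -0.1453, -0.0274])
```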
Update August 26, 2020
This feature is being implemented in PR#43092 and the functions will be called amin and amax. They will return only the values. This is probably being merged soon, so you might be able to access these functions on the nightly build by the time you're reading this :) Have fun.
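A minimal sketch of amin, assuming PyTorch 1.7 or later, checked against the reshape-based workaround from the original answer applied to .min(). Like amax, it returns values only, with no indices.

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

# amin reduces over several dims at once.
mins = torch.amin(x, dim=(1, 2))

# The pre-amin workaround: flatten dims 1 and 2, then reduce over one dim.
mins_workaround = x.reshape(x.size(0), -1).min(dim=-1).values

print(mins)  # tensor([-0.3000, -0.1821, -0.0642])
```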
Although Berriel's solution solves this specific question, I thought some explanation might help shed light on the trick employed here, so that it can be adapted to (m)any other dimensions.
Let's start by inspecting the shape of the input tensor x:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
So, we have a 3D tensor of shape (3, 2, 2). Now, as per the OP's question, we need to compute the maximum of the values in the tensor along both the 1st and 2nd dimensions. As of this writing, the dim argument of torch.max() supports only an int, so we can't pass a tuple. Hence, we will use the following trick, which I'll call:
The Flatten & Max Trick: since we want to compute max over both the 1st and 2nd dimensions, we will flatten both of these dimensions into a single dimension and leave the 0th dimension untouched. This is exactly what happens when we do:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
So, now we have shrunk the 3D tensor to a 2D tensor (i.e. a matrix).
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
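The same two reshapes can also be written with the start_dim / end_dim arguments of Tensor.flatten, which merge only a range of dimensions. This is an equivalent spelling, not something the answer above relies on.

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

# Merge dims 1..-1 into one, keeping dim 0: shape (3, 4).
rows_flat = x.flatten(start_dim=1)

# Merge dims 0..1 into one, keeping the last dim: shape (6, 2).
cols_flat = x.flatten(end_dim=1)

print(rows_flat.shape)  # torch.Size([3, 4])
print(cols_flat.shape)  # torch.Size([6, 2])
```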
Now, we can simply apply max over the 1st dimension (in this case, the first dimension is also the last dimension), since the flattened dimensions reside in that dimension.
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
We got 3 values in the resultant tensor since we had 3 rows in the matrix.
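The indices returned here point into the flattened (2*2 = 4) axis. A sketch of mapping them back to positions in the original (2, 2) slices with integer division and modulo by the size of the last dimension (the variable names are illustrative):

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

res = x.flatten().reshape(x.shape[0], -1).max(dim=1)

# Row-major flattening: flat = row * ncols + col.
ncols = x.shape[2]
rows = res.indices // ncols
cols = res.indices % ncols

# Sanity check: pick the elements back out of the original 3D tensor.
picked = x[torch.arange(x.shape[0]), rows, cols]
```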
Now, on the other hand, if you want to compute max over the 0th and 1st dimensions, you'd do:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
Now, we can simply apply max over the 0th dimension, since that is the result of our flattening. (Also, given the original shape of (3, 2, 2), taking max over the first two dimensions should leave two values as the result.)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
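On PyTorch 1.7 or later, the result can be cross-checked against torch.amax with a tuple of dims. A minimal sketch:

```python
import torch

x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

via_trick = x.flatten().reshape(-1, x.shape[-1]).max(dim=0).values
via_amax = torch.amax(x, dim=(0, 1))

print(via_amax)  # tensor([-0.0347, -0.0274])
```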
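In a similar vein, you can adapt this approach to multiple dimensions and to other reduction functions such as min. Note that the dimensions being flattened together must be adjacent; for non-adjacent ones (e.g. the 0th and 2nd), first permute them next to each other.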
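A sketch of that generalization as a hypothetical helper, reduce_over_dims (not a PyTorch function): it permutes the reduced dims to the end, flattens them with reshape (permute generally leaves the tensor non-contiguous), and applies a single-dim reduction such as torch.max or torch.min.

```python
import torch
from typing import Callable, Sequence


def reduce_over_dims(x: torch.Tensor, dims: Sequence[int],
                     reducer: Callable = torch.max) -> torch.Tensor:
    """Hypothetical helper: reduce x over dims using permute + flatten + one-dim reduce."""
    ndim = x.dim()
    red = sorted({d % ndim for d in dims})  # normalize negative dims, drop duplicates
    keep = [d for d in range(ndim) if d not in red]
    # Kept dims first, reduced dims last, so the reduced ones become adjacent.
    moved = x.permute(*keep, *red)
    kept_shape = [x.shape[d] for d in keep]
    # reshape, not view: the permuted tensor is usually non-contiguous.
    flat = moved.reshape(*kept_shape, -1)
    return reducer(flat, dim=-1).values


x = torch.tensor([
    [[-0.3000, -0.2926], [-0.2705, -0.2632]],
    [[-0.1821, -0.1747], [-0.1526, -0.1453]],
    [[-0.0642, -0.0568], [-0.0347, -0.0274]],
])

print(reduce_over_dims(x, (1, 2)))             # adjacent dims, as in the question
print(reduce_over_dims(x, (0, 2)))             # non-adjacent dims
print(reduce_over_dims(x, (0, 2), torch.min))  # same idea with min
```

Where torch.amax / torch.amin are available, they are the simpler choice; the helper mainly spells out what the trick does.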
Note: I'm following the terminology of 0-based dimensions (0, 1, 2, 3, ...) just to be consistent with PyTorch usage and the code.