
AttributeError: 'torch.return_types.max' object has no attribute 'dim' - Maxpooling Channel

I'm trying to do max pooling over the channel dimension:

import torch
import torch.nn as nn

class ChannelPool(nn.Module):
    def forward(self, input):
        return torch.max(input, dim=1)

but I get the error:

AttributeError: 'torch.return_types.max' object has no attribute 'dim'
Guglie asked Mar 25 '20 10:03


2 Answers

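When torch.max is called with dim, it returns a namedtuple (values, indices) rather than a tensor. Whatever consumes the module's output then calls .dim() on that namedtuple and fails with this AttributeError. Unpack the result and return only the values: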

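import torch
import torch.nn as nn

class ChannelPool(nn.Module):
    def forward(self, input):
        # With dim given, torch.max returns (values, indices); keep only the values.
        input_max, max_indices = torch.max(input, dim=1)
        return input_max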
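A minimal sketch of the same module used inside a model, assuming you want the output to stay 4-D (N, 1, H, W) so it can feed a layer such as nn.Conv2d. The keepdim=True argument and the surrounding nn.Sequential with a hypothetical Conv2d(1, 8, ...) are illustrative additions, not part of the question:

```python
import torch
import torch.nn as nn


class ChannelPool(nn.Module):
    """Max-pool over the channel dimension (dim=1) of an (N, C, H, W) tensor."""

    def forward(self, input):
        # keepdim=True keeps a size-1 channel axis, so the output stays 4-D.
        # .values selects the max values from the (values, indices) namedtuple.
        return torch.max(input, dim=1, keepdim=True).values


# Hypothetical usage: collapse 3 channels to 1, then convolve.
model = nn.Sequential(
    ChannelPool(),
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
)

x = torch.randn(2, 3, 16, 16)
pooled = ChannelPool()(x)
out = model(x)
print(pooled.shape)  # torch.Size([2, 1, 16, 16])
print(out.shape)     # torch.Size([2, 8, 16, 16])
```

Without keepdim=True the channel axis is dropped and the result is (N, H, W), which is fine if nothing downstream expects a channel dimension.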

From the documentation of torch.max:

Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim. And indices is the index location of each maximum value found (argmax).
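A small illustration of the namedtuple described above: it can be unpacked like a tuple or accessed by field name. The last check assumes PyTorch 1.7 or later, where torch.amax returns only the values as a plain tensor:

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

result = torch.max(x, dim=1)   # a torch.return_types.max, not a Tensor
values, indices = result       # tuple-style unpacking
print(values)   # tensor([5., 6.])
print(indices)  # tensor([1, 2])

# The same fields are available by name:
assert torch.equal(result.values, values)
assert torch.equal(result.indices, indices)

# torch.amax (PyTorch 1.7+) returns just the values as a tensor:
assert torch.equal(torch.amax(x, dim=1), values)
```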

Guglie answered Nov 14 '22 01:11


I ran into the same error recently. There are two forms of torch.max():

  • If you pass only the input tensor (no other arguments such as dim), max() returns a tensor holding the single maximum over all elements.

  • If you also pass dim (for example dim=0), max() returns a namedtuple (values, indices). The values field is probably what you want.
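The two forms listed above, side by side on a small example tensor (the values are illustrative):

```python
import torch

x = torch.tensor([[1.0, 5.0],
                  [4.0, 2.0]])

# Form 1: no dim -> a 0-dim tensor holding the global maximum.
overall = torch.max(x)
print(overall)  # tensor(5.)

# Form 2: with dim -> namedtuple (values, indices), here one max per column.
per_column = torch.max(x, dim=0)
print(per_column.values)   # tensor([4., 5.])
print(per_column.indices)  # tensor([1, 0])
```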

ZhaoYi answered Nov 14 '22 01:11