PyTorch runtime error: invalid argument 0: Sizes of tensors must match except in dimension 1

I have a PyTorch model and I'm trying to test it by performing a forward pass. Here is the code:

import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(ResBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        #batch normalization
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.stride = stride

    def forward(self, x):
        residual = self.conv1x1(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        #adding the skip connection
        out += residual
        out = self.relu(out)

        return out

class ResUnet(nn.Module):

    def __init__(self, in_shape, num_classes):
        super(ResUnet, self).__init__()
        in_channels, height, width = in_shape
        #
        #self.L1 = IncResBlock(in_channels,64)
        self.e1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2,padding=1),
            ResBlock(64,64))


        self.e2 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(64, 128, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(128),
            ResBlock(128,128))
        #
        self.e2add = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(128))
        #
        ##
        self.e3 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(128,256, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(256),
            ResBlock(256,256))

        self.e4 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(256,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        self.e4add = nn.Sequential(
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512)) 
        #
        self.e5 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        #
        self.e6 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512), 
            ResBlock(512,512))
        #
        self.e6add = nn.Sequential(
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512)) 
        #
        self.e7 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        self.e8 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1))
            #nn.BatchNorm2d(512))

        self.d1 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d2 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d3 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d4 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))

        #
        self.d5 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(256),
            ResBlock(256,256))
        #
        self.d6 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(128),
            ResBlock(128,128))
        #
        self.d7 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(64),
            ResBlock(64,64))
        #
        self.d8 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2,padding=1))
            #nn.BatchNorm2d(64),
            #nn.ReLU())

        self.out_l = nn.Sequential(
            nn.Conv2d(64,num_classes,kernel_size=1,stride=1))
            #nn.ReLU())

    def forward(self, x):

        #Image Encoder

        #### Encoder #####

        en1 = self.e1(x)

        en2 = self.e2(en1)
        en2add = self.e2add(en2)

        en3 = self.e3(en2add)

        en4 = self.e4(en3)
        en4add = self.e4add(en4)

        en5 = self.e5(en4add)

        en6 = self.e6(en5)
        en6add = self.e6add(en6)

        en7 = self.e7(en6add)

        en8 = self.e8(en7)

        #### Decoder ####
        de1_ = self.d1(en8)
        de1 = torch.cat([en7,de1_],1)

        de2_ = self.d2(de1)
        de2 = torch.cat([en6add,de2_],1)


        de3_ = self.d3(de2)
        de3 = torch.cat([en5,de3_],1)


        de4_ = self.d4(de3)
        de4 = torch.cat([en4add,de4_],1)


        de5_ = self.d5(de4)
        de5 = torch.cat([en3,de5_],1)

        de6_ = self.d6(de5)
        de6 = torch.cat([en2add,de6_],1)

        de7_ = self.d7(de6)
        de7 = torch.cat([en1,de7_],1)
        de8 = self.d8(de7)

        out_l_mask = self.out_l(de8)

        return out_l_mask  

Here is how I attempt to test it:

modl = ResUnet((1,512,512), 1)
x = torch.rand(1, 1, 512, 512)
modl(x)

This works fine, as it does for any input size that is a multiple of 64.

If I try:

modl = ResUnet((1,320,320), 1)
x = torch.rand(1, 1, 320, 320)
modl(x)

It throws an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-46-4ddc821c365b> in <module>
----> 1 modl(x)

~/.conda/envs/torch0.4/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

<ipython-input-36-f9eeefa3c0b8> in forward(self, x)
    221         de2_ = self.d2(de1)
    222         #print de2_.size()
--> 223         de2 = torch.cat([en6add,de2_],1)
    224         #print de2.size()
    225 

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 5 and 4 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616

I figure the problem is caused by the input size not being a power of 2, but I am not sure how to rectify it for the given input dimensions (320, 320).

Asked by Reptilian on Oct 27 '25 17:10

1 Answer

This issue arises from a mismatch in size between the tensors on the downsampling (encoder) path and the upsampling (decoder) path. Your code is huge and difficult to understand, but by inserting print statements (or hooks, as sketched after the list below), we can check that

  1. en6add is of size [1, 512, 5, 5]
  2. en7 is [1, 512, 2, 2]
  3. en8 is [1, 512, 1, 1]
  4. then upsampling goes as powers of two: de1_ is [1, 512, 2, 2]
  5. de1 [1, 1024, 2, 2]
  6. de2_ [1, 512, 4, 4]
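
One way to obtain these shapes without editing forward is to register a forward hook on each top-level child of the model. This is only a debugging sketch; the add_shape_hooks helper below is my own addition, not part of the question's code.

import torch

def add_shape_hooks(model):
    # print the output shape of every top-level child module as it runs
    handles = []
    for name, module in model.named_children():
        def hook(mod, inp, out, name=name):
            print(f"{name}: {tuple(out.shape)}")
        handles.append(module.register_forward_hook(hook))
    return handles

modl = ResUnet((1, 320, 320), 1)
handles = add_shape_hooks(modl)
try:
    modl(torch.rand(1, 1, 320, 320))  # prints the e1 ... d2 shapes, then fails at the cat in forward
except RuntimeError as err:
    print("failed:", err)
finally:
    for h in handles:
        h.remove()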

At that point (de2_, of size [1, 512, 4, 4]) you try to concatenate it with en6add, so apparently the code creating de2_ is not "upsampling enough". My strong guess is that you need to pay attention to the output_padding parameter of nn.ConvTranspose2d and possibly set it to 1 in a couple of places. I would try to fix this error for you, but the example is so far from minimal that I can't wrap my head around the whole of it.
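
For reference, with kernel_size=4, stride=2, padding=1 an nn.ConvTranspose2d maps a spatial size n to 2*n + output_padding, so with the default output_padding=0 each decoder stage can only exactly double a dimension: it can turn the 2x2 map into 4x4, but never into the 5x5 that en6add has. A minimal sketch of the effect (standalone layers for illustration only, not a drop-in patch for d2, since which decoder stages need the extra padding depends on the input size):

import torch
import torch.nn as nn

def convtranspose_out(n, kernel_size=4, stride=2, padding=1, output_padding=0):
    # output-size formula from the nn.ConvTranspose2d documentation (dilation=1)
    return (n - 1) * stride - 2 * padding + kernel_size + output_padding

print(convtranspose_out(2))                    # 4 -- cannot be concatenated with a 5x5 map
print(convtranspose_out(2, output_padding=1))  # 5 -- matches en6add

x = torch.rand(1, 512, 2, 2)
up_default = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1)
up_padded = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, output_padding=1)
print(up_default(x).shape)  # torch.Size([1, 512, 4, 4])
print(up_padded(x).shape)   # torch.Size([1, 512, 5, 5])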

Answered by Jatentaki on Oct 30 '25 06:10


