Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Why is my Fully Convolutional Autoencoder not symmetric?

I am developing a Fully Convolutional Autoencoder which takes 3 channels as input and outputs 2 channels (in: LAB, out: AB). Because the output should be the same size as the input, I use Full Convolution.

The Code:

import torch.nn as nn


class AE(nn.Module):
    """Convolutional autoencoder: 3 input channels in, 2 output channels out.

    The encoder is five 5x5 Conv2d + BatchNorm2d + ReLU stages; the first
    four are each followed by a 2x2 max-pool.  The decoder is five 5x5
    ConvTranspose2d stages; the last four are each preceded by a 2x
    bilinear upsample, and the final stage ends in a softmax over the
    channel axis.

    NOTE: with kernel_size=5, stride=1 and padding=1 every Conv2d reduces
    H and W by 2 and every ConvTranspose2d increases them by 2, while
    MaxPool2d rounds odd sizes down.  The output's spatial size is
    therefore generally NOT equal to the input's (e.g. a 199x253 input
    yields a 190x238 output).
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # conv 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 4
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 5 (bottleneck, no pooling)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            # conv 6 (mirrors conv 5, no upsampling)
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            # conv 7
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            # conv 8
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            # conv 9
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            # conv 10 out
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=5, stride=1, padding=1),
            # multi-class classification; dim=1 is the channel axis of the
            # (N, C, H, W) tensor.  This is what the implicit-dim form
            # resolved to for 4-D input, made explicit because calling
            # nn.Softmax() without `dim` is deprecated.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder.

        Args:
            x: 4-D tensor (N, 3, H, W), as required by the Conv2d /
                BatchNorm2d layers.

        Returns:
            Tensor with 2 channels whose values sum to 1 along dim 1.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x

The size the output tensor should be: torch.Size([1, 2, 199, 253])

The size the output tensor really has: torch.Size([1, 2, 190, 238])

My main problem is combining Conv2d and MaxPool2d and setting the correct parameter values in the ConvTranspose2d. Because of that, I treat them separately, using the Upsample function to undo the MaxPool2d and ConvTranspose2d only to undo the Conv2d. But I still have a little asymmetry and I really don't know why.

Thank you for the help!

like image 461
Stefan Avatar asked Oct 18 '25 09:10

Stefan


1 Answers

There are two issues.

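First is insufficient padding: with kernel_size=5, an unpadded convolution shrinks the image by 4 pixels every time it is applied (2 pixels on each side). With padding=1 you only restore half of that, so each Conv2d still shrinks the image by 2 pixels (and each ConvTranspose2d grows it by 2). To keep the size unchanged you need padding=2, and not just 1, in all places.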

Second is the "uneven" input size. What I mean is that once your convolutions are properly padded, you are left with downsampling operations, each of which tries to halve your image resolution. When the size is odd, they just return a smaller result (integer division discards the remainder), and the later 2x upsampling cannot recover the lost pixel. Since your network has 4 successive 2x downsampling operations, you need your input to have H, W dimensions which are multiples of 2^4=16. Then you will actually get an output of the same shape as the input. An example is given below:

import torch
import torch.nn as nn

class AE(nn.Module):
    """Size-preserving convolutional autoencoder: 3 channels in, 2 out.

    All 5x5 convolutions and transposed convolutions use stride=1 and
    padding=2, so they leave H and W unchanged.  The only size changes
    come from the four 2x2 max-pools in the encoder and the four 2x
    bilinear upsamples in the decoder.

    Because MaxPool2d rounds odd sizes down, the output has the same
    spatial size as the input only when H and W are multiples of
    2**4 == 16; otherwise the output is smaller than the input.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # conv 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 4
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # conv 5 (bottleneck, no pooling)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            # conv 6 (mirrors conv 5, no upsampling)
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            # conv 7
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            # conv 8
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            # conv 9
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            # conv 10 out
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=5, stride=1, padding=2),
            # multi-class classification; dim=1 is the channel axis of the
            # (N, C, H, W) tensor.  This is what the implicit-dim form
            # resolved to for 4-D input, made explicit because calling
            # nn.Softmax() without `dim` is deprecated.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder.

        Args:
            x: 4-D tensor (N, 3, H, W), as required by the Conv2d /
                BatchNorm2d layers.

        Returns:
            Tensor with 2 channels whose values sum to 1 along dim 1.
            Its H and W equal the input's when both are multiples of 16.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Demo: feed a random batch whose H and W (96 x 112) are multiples of 16 and
# print the input and output shapes for comparison.
# Named `sample` rather than `input` so the Python builtin is not shadowed.
sample = torch.randn(1, 3, 6*16, 7*16)
output = AE()(sample)
print(sample.shape)
print(output.shape)
like image 107
Jatentaki Avatar answered Oct 20 '25 00:10

Jatentaki



Donate For Us

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!