
Pytorch Convolutional Autoencoders

How does one construct the decoder part of a convolutional autoencoder? Suppose I have this

(input -> conv2d -> maxpool2d -> maxunpool2d -> convTranspose2d -> output):

# CIFAR images shape = 3 x 32 x 32

class ConvDAE(nn.Module):
    def __init__(self):
        super().__init__()

        # input: batch x 3 x 32 x 32 -> output: batch x 16 x 16 x 16
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1), # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2) # batch x 16 x 16 x 16
        )

        # input: batch x 16 x 16 x 16 -> output: batch x 3 x 32 x 32
        self.decoder = nn.Sequential(
            # this line does not work
            # nn.MaxUnpool2d(2, stride=2, padding=0), # batch x 16 x 32 x 32
            nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, output_padding=1), # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1, output_padding=0), # batch x 3 x 32 x 32
            nn.ReLU()
        )

    def forward(self, x):
        print(x.size())
        out = self.encoder(x)
        print(out.size())
        out = self.decoder(out)
        print(out.size())
        return out

PyTorch-specific question: why can't I use MaxUnpool2d in the decoder part? It gives me the following error:

TypeError: forward() missing 1 required positional argument: 'indices'
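A minimal reproduction of the error (a sketch, assuming only `torch` is installed): calling a `MaxUnpool2d` module with just the input tensor triggers the same `TypeError`, because its `forward` also expects an `indices` argument.

```python
import torch
import torch.nn as nn

unpool = nn.MaxUnpool2d(2, stride=2)
x = torch.randn(1, 16, 16, 16)

try:
    unpool(x)  # indices argument omitted, as in the Sequential above
except TypeError as e:
    err = e
    print(e)  # complains about the missing 'indices' argument
```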

And the conceptual question: shouldn't the decoder do the inverse of whatever the encoder did? I saw some implementations, and it seems they only care about the dimensions of the decoder's input and output. Here and here are some examples.

asked Oct 24 '25 by torayeff


1 Answer

For the PyTorch part of the question: unpooling modules require, as a positional argument, the indices returned by the corresponding pooling module, and those indices are only returned when the pooling layer is constructed with return_indices=True. So you could do:

class ConvDAE(nn.Module):
    def __init__(self):
        super().__init__()

        # input: batch x 3 x 32 x 32 -> output: batch x 16 x 16 x 16
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1), # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2, return_indices=True) # batch x 16 x 16 x 16, plus indices
        )

        self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0) # batch x 16 x 32 x 32

        # unpooling already restores the 32 x 32 resolution, so both
        # transposed convolutions keep stride=1 (with stride=2 the first
        # layer would grow the output to 64 x 64)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 16, 3, stride=1, padding=1), # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1), # batch x 3 x 32 x 32
            nn.ReLU()
        )

    def forward(self, x):
        print(x.size())
        out, indices = self.encoder(x)
        out = self.unpool(out, indices)
        out = self.decoder(out)
        print(out.size())
        return out

As for the general part of the question: I don't think a symmetric decoder is state of the art, as it has been shown that deconvolution/transposed convolution produces checkerboard artifacts, and many approaches use upsampling modules instead. You will find more info faster through PyTorch channels.
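To illustrate the upsampling alternative, here is a sketch of a decoder (my own, not from the answer above) that replaces the transposed convolutions with nn.Upsample followed by ordinary convolutions, which is the usual way to avoid checkerboard artifacts; the final Sigmoid is an assumption for image outputs in [0, 1].

```python
import torch
import torch.nn as nn

# upsample-then-convolve decoder: same 16 x 16 x 16 -> 3 x 32 x 32 mapping
decoder = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),  # batch x 16 x 32 x 32
    nn.Conv2d(16, 16, 3, stride=1, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(16),
    nn.Conv2d(16, 3, 3, stride=1, padding=1),     # batch x 3 x 32 x 32
    nn.Sigmoid()  # assumed: squash reconstructions into [0, 1]
)

out = decoder(torch.randn(4, 16, 16, 16))
print(out.shape)  # torch.Size([4, 3, 32, 32])
```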

answered Oct 25 '25 by PSL


