PyTorch skip connection in a sequential model

I am trying to wrap my head around skip connections in a sequential model. With the functional API I would be doing something as easy as this (a quick example; it may not be 100% syntactically correct, but it should get the idea across):

x1 = self.conv1(inp)
x = self.conv2(x1)
x = self.conv3(x)
x = self.conv4(x)

x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1)  # skip: reuse the conv1 activations
x = self.deconv1(x)

I am now using a sequential model and trying to do something similar: create a skip connection that brings the activations of the first conv layer all the way to the last convTranspose. I have taken a look at the U-Net architecture implemented here, and it's a bit confusing; it does something like this:

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                            kernel_size=4, stride=2,
                            padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up

Isn't this just adding layers to the sequential model, well, sequentially? There is the down conv, followed by submodule (which recursively adds inner layers), and then concatenated with up, which is the upconv layer. I am probably missing something important about how the Sequential API works, but how does the code snippet from U-Net actually implement the skip?

asked Aug 09 '18 by powder


1 Answer

Your observations are correct, but you may have missed the definition of UnetSkipConnectionBlock.forward() (UnetSkipConnectionBlock being the Module defining the U-Net block you shared), which may clarify this implementation:

(from pytorch-CycleGAN-and-pix2pix/models/networks.py#L259)

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):

    # ...

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

The last line is the key (applied for all inner blocks). The skip layer is simply done by concatenating the input x and the (recursive) block output self.model(x), with self.model being the list of operations you mentioned -- so it is not so different from the functional code you wrote.
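If it helps to see the mechanism in isolation, here is a minimal, self-contained sketch of the same idea (my own simplified version, not the repo's code): each block wraps an nn.Sequential, and the skip lives only in forward() as the concatenation of input and output. Normalization, dropout, and the outermost-specific activations are left out, and names like SkipBlock are purely illustrative:

import torch
import torch.nn as nn

# Illustrative simplification of the U-Net block: an nn.Sequential does
# down -> submodule -> up, and forward() adds the skip by concatenation.
class SkipBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, submodule=None, outermost=False):
        super().__init__()
        self.outermost = outermost
        downconv = nn.Conv2d(outer_nc, inner_nc,
                             kernel_size=4, stride=2, padding=1)
        # Non-outermost submodules return cat([x, model(x)]), i.e. twice the
        # channels, so the upsampling layer must accept inner_nc * 2 inputs.
        up_in = inner_nc if submodule is None else inner_nc * 2
        upconv = nn.ConvTranspose2d(up_in, outer_nc,
                                    kernel_size=4, stride=2, padding=1)
        layers = [nn.LeakyReLU(0.2), downconv]
        if submodule is not None:
            layers.append(submodule)
        layers += [nn.ReLU(), upconv]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)                 # top level: no skip
        return torch.cat([x, self.model(x)], 1)  # skip: input || output

# Three nested levels, built innermost first:
inner = SkipBlock(outer_nc=64, inner_nc=128)
mid = SkipBlock(outer_nc=32, inner_nc=64, submodule=inner)
net = SkipBlock(outer_nc=3, inner_nc=32, submodule=mid, outermost=True)

out = net(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 32, 32])

Each non-outermost block returns its input concatenated with its output, doubling the channel count on the way out; that is why the upconv in the repo takes inner_nc * 2 input channels.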

answered Oct 04 '22 by benjaminplanche