
How to implement my own ResNet with torch.nn.Sequential in Pytorch?

I want to implement a ResNet network (or rather, residual blocks), but I really want it to be in sequential network form.

By sequential network form I mean the following:

from collections import OrderedDict
import torch.nn as nn

## mdl5, from the CIFAR-10 tutorial
mdl5 = nn.Sequential(OrderedDict([
    ('pool1', nn.MaxPool2d(2, 2)),
    ('relu1', nn.ReLU()),
    ('conv1', nn.Conv2d(3, 6, 5)),
    ('pool2', nn.MaxPool2d(2, 2)),
    ('relu2', nn.ReLU()),
    ('conv2', nn.Conv2d(6, 16, 5)),
    ('relu3', nn.ReLU()),
    ('flatten', nn.Flatten()),
    ('fc1', nn.Linear(1024, 120)), # figure out equation properly
    ('relu4', nn.ReLU()),
    ('fc2', nn.Linear(120, 84)),
    ('relu5', nn.ReLU()),
    ('fc3', nn.Linear(84, 10))
]))

but of course with the NN lego blocks being “ResNet”.

I know the equation is something like:

y = F(x) + x

but I am not sure how to do it in PyTorch AND Sequential. Sequential is key for me!


Bounty:

I'd like to see an example with a fully connected net, showing where the BN layers would have to go (and where the dropout layers would go too). Ideally on a toy example/data if possible.


Cross-posted:

  • https://discuss.pytorch.org/t/how-to-have-residual-network-using-only-sequential-blocks/51541
  • https://www.quora.com/unanswered/How-does-one-implement-my-own-ResNet-with-torch-nn-Sequential-in-Pytorch
  • https://www.reddit.com/r/pytorch/comments/uyyr28/how_to_implement_my_own_resnet_with/
asked Jul 27 '19 04:07 by Charlie Parker

1 Answer

You can't do it solely using torch.nn.Sequential, because, as the name suggests, it runs operations sequentially, while a residual block runs its skip connection in parallel with the main branch.

You could, in principle, construct your own block really easily like this:

import torch

class ResNet(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        # Residual/skip connection: add the block's input to its output.
        # Requires self.module to preserve the input's shape.
        return self.module(inputs) + inputs

which you can then use like this:

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=7),
    # 32 filters in and out; padding=1 keeps the spatial size unchanged
    # so the block's output can be added to its input
    ResNet(
        torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
        )
    ),
    # Another ResNet block; you could stack more of them,
    # with downsampling (max pooling etc.) in between
    ResNet(
        torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(32),
        )
    ),
    # Pool each of the 32 feature maps down to 1x1 ...
    torch.nn.AdaptiveAvgPool2d(1),
    # ... and flatten (N, 32, 1, 1) to (N, 32) for the classifier
    torch.nn.Flatten(),
    # 32 features in, 10 classes out
    torch.nn.Linear(32, 10),
)
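A quick way to convince yourself the residual addition works is to check shapes on a dummy batch. The sketch below is self-contained: `ResidualBlock` is a standalone copy of the `ResNet` wrapper above (the name is just to avoid confusion), and the input size is arbitrary.

```python
import torch

class ResidualBlock(torch.nn.Module):
    """Standalone copy of the ResNet wrapper above."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        return self.module(inputs) + inputs

block = ResidualBlock(
    torch.nn.Sequential(
        # padding=1 with kernel_size=3 preserves height and width,
        # which is what makes `module(inputs) + inputs` legal
        torch.nn.Conv2d(32, 32, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.BatchNorm2d(32),
    )
)

x = torch.randn(4, 32, 16, 16)  # dummy batch: N=4, 32 channels, 16x16
y = block(x)
print(y.shape)  # same shape as the input, so the skip addition works
```

Note that without `padding=1` a 3x3 convolution shrinks the feature map by 2 in each spatial dimension, and the addition in `forward` would raise a shape-mismatch error.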

A fact that is usually overlooked (with no real consequences for shallower networks) is that the skip connection should be left without any nonlinearities like ReLU or convolutional layers, and that's what you see above (source: Identity Mappings in Deep Residual Networks).
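Regarding the bounty's fully connected case: here is a minimal sketch on toy data, assuming the pre-activation ordering (BN → ReLU → Dropout → Linear) inside each residual branch, in the spirit of the Identity Mappings paper; the skip path itself carries no BN or dropout. The widths, dropout rate, and the `fc_residual_block` helper are illustrative choices, not a canonical recipe.

```python
import torch

class Residual(torch.nn.Module):
    """Same identity-skip wrapper as above, for fully connected branches."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        return self.module(inputs) + inputs

def fc_residual_block(width, p_drop=0.1):
    # BN and dropout live inside the branch; the skip path stays identity.
    # Pre-activation ordering: BN -> ReLU -> Dropout -> Linear, twice.
    return Residual(
        torch.nn.Sequential(
            torch.nn.BatchNorm1d(width),
            torch.nn.ReLU(),
            torch.nn.Dropout(p_drop),
            torch.nn.Linear(width, width),
            torch.nn.BatchNorm1d(width),
            torch.nn.ReLU(),
            torch.nn.Dropout(p_drop),
            torch.nn.Linear(width, width),
        )
    )

model = torch.nn.Sequential(
    torch.nn.Linear(8, 32),   # toy input: 8 features up to width 32
    fc_residual_block(32),
    fc_residual_block(32),
    torch.nn.Linear(32, 2),   # toy output: 2 classes
)

# One training step on random toy data, just to show it runs end to end.
x = torch.randn(16, 8)
targets = torch.randint(0, 2, (16,))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = torch.nn.functional.cross_entropy(model(x), targets)
loss.backward()
optimizer.step()
print(model(x).shape)  # (16, 2): one logit pair per toy sample
```

All hidden layers share one width (32) precisely so every skip addition is shape-compatible; changing width mid-network would require a projection (e.g. a `Linear`) on the skip path instead of a pure identity.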

answered Oct 27 '22 14:10 by Szymon Maszke