 

pytorch custom layer "is not a Module subclass"

I am new to PyTorch, trying it out after using a different toolkit for a while.

I would like to understand how to program custom layers and functions. As a simple test, I wrote this:

import torch
import torch.nn as nn

class Testme(nn.Module):         ## it _is_ a subclass of Module ##
    def __init__(self):
        super(Testme, self).__init__()

    def forward(self, x):
        # scale the tensor by its global maximum
        return x / torch.max(x)

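which is intended to normalize the data passing through it so that it sums to 1 (strictly, dividing by the maximum makes the largest value 1; dividing by the sum instead would make the values sum to 1). It is not actually useful, just a test.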
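A quick sketch of what the two normalizations do on a small positive tensor, assuming a recent PyTorch version. SumToOne is a hypothetical name for the sum-based variant; both layers assume the input's maximum (or sum) is nonzero.

```python
import torch
import torch.nn as nn


class Testme(nn.Module):
    """Scales its input so that the largest element becomes 1."""

    def forward(self, x):
        return x / torch.max(x)


class SumToOne(nn.Module):
    """Hypothetical variant: scales its input so that the elements sum to 1."""

    def forward(self, x):
        return x / torch.sum(x)


x = torch.tensor([1.0, 2.0, 5.0])
print(Testme()(x))    # tensor([0.2000, 0.4000, 1.0000])
print(SumToOne()(x))  # tensor([0.1250, 0.2500, 0.6250])
```

Both are ordinary Module subclasses; note that each is called on an instance (Testme(), SumToOne()), not on the class itself.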

Then I plugged it into the example code from the PyTorch Playground:

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for i, v in enumerate(cfg):
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            padding = v[1] if isinstance(v, tuple) else 1
            out_channels = v[0] if isinstance(v, tuple) else v
            conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(out_channels, affine=False), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            layers += [Testme]                           # here <------------------
            in_channels = out_channels
    return nn.Sequential(*layers)

The result is an error!

TypeError: model.Testme is not a Module subclass

Maybe this needs to be a Function rather than a Module? It is also not clear to me what the difference between a Function and a Module is.

For example, why does a Function need a backward() method, even if it is built entirely from standard PyTorch primitives, whereas a Module does not?
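A minimal sketch of the distinction, assuming a recent PyTorch version with the static-method style of torch.autograd.Function (SquareFn and SquareModule are hypothetical names). Autograd does not record the operations performed inside a custom Function's forward(), so the gradient rule has to be supplied by hand in backward(), even when forward() only uses standard PyTorch ops. A Module's forward() runs under normal autograd tracking, so its gradient is derived automatically.

```python
import torch
import torch.nn as nn


class SquareFn(torch.autograd.Function):
    """Hypothetical custom autograd Function computing x * x."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep the input for the gradient rule
        return x * x              # not recorded by autograd

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # hand-written: d(x^2)/dx = 2x


class SquareModule(nn.Module):
    """Same computation as a Module: autograd derives backward for us."""

    def forward(self, x):
        return x * x


x = torch.randn(5, dtype=torch.double, requires_grad=True)

# Gradient from the hand-written backward()
SquareFn.apply(x).sum().backward()
grad_fn = x.grad.clone()

# Gradient from autograd's automatic differentiation of the Module
x.grad = None
SquareModule()(x).sum().backward()
grad_module = x.grad.clone()

print(torch.allclose(grad_fn, grad_module))  # True
```

Note that a custom Function is invoked through SquareFn.apply(...) rather than by instantiating it, and it cannot be placed in nn.Sequential directly; wrapping it in a small Module's forward() is the usual way to use it as a layer.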

asked Jun 07 '17 07:06 by forgotmysocks


1 Answer

That's a simple one. You almost got it, but you forgot to actually create an instance of your new class Testme. You need to do this even when the constructor takes no arguments (as with Testme); it is just easier to forget than with a convolutional layer, to which you typically pass several arguments. nn.Sequential expects module instances, so passing the class object Testme itself raises the TypeError.
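To see why the error appears, here is a small sketch assuming a recent PyTorch version. As far as I can tell from the source, nn.Sequential registers each argument via add_module, which checks isinstance(module, nn.Module). The class Testme fails that check while Testme() passes it, which is why the message talks about a "Module subclass" even though what is actually missing is an instance.

```python
import torch
import torch.nn as nn


class Testme(nn.Module):
    def forward(self, x):
        return x / torch.max(x)


# The class object itself is not a Module instance...
print(isinstance(Testme, nn.Module))    # False
# ...but an instance of it is.
print(isinstance(Testme(), nn.Module))  # True

message = None
try:
    nn.Sequential(Testme)               # passes the class, not an instance
except TypeError as err:
    message = str(err)                  # ends with "is not a Module subclass"
print(message)

seq = nn.Sequential(Testme())           # passing an instance works
```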

Change the line you have indicated to the following and your problem is resolved.

layers += [Testme()]
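For completeness, a runnable sketch of the corrected make_layers with the fix applied, assuming a recent PyTorch version. The small configuration [8, 'M', 16] and the 1x3x32x32 input are hypothetical, chosen only to check that the model builds and runs.

```python
import torch
import torch.nn as nn


class Testme(nn.Module):
    def forward(self, x):
        return x / torch.max(x)


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            padding = v[1] if isinstance(v, tuple) else 1
            out_channels = v[0] if isinstance(v, tuple) else v
            conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(out_channels, affine=False), nn.ReLU()]
            else:
                layers += [conv2d, nn.ReLU()]
            layers += [Testme()]  # an instance, not the class
            in_channels = out_channels
    return nn.Sequential(*layers)


torch.manual_seed(0)
model = make_layers([8, 'M', 16])        # hypothetical small config
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)                         # torch.Size([1, 16, 16, 16])
```

Because the last layer is Testme applied after a ReLU, the largest output value should be exactly 1 as long as at least one activation is positive.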
answered Sep 22 '22 21:09 by mbpaulus