 

Why is the super constructor necessary in PyTorch custom modules?

Why does super(LR, self).__init__() need to be called in the code below? Without it I get the error "AttributeError: cannot assign module before Module.__init__() call", raised by the line self.linear = nn.Linear(input_size, output_size).

I don't understand what the connection is between calling super(LR, self).__init__() and being able to assign the nn.Linear object to self.linear. nn.Linear is a separate object which can be assigned to a variable outside of any class, so why does super(LR, self).__init__() need to be called to assign a Linear object to self.linear within the class?

import torch.nn as nn

class LR(nn.Module):

    # Constructor
    def __init__(self, input_size, output_size):
        # Initialise the parent nn.Module (sets up its internal state)
        super(LR, self).__init__()
        self.test = 1
        self.linear = nn.Linear(input_size, output_size)

    # Prediction function
    def forward(self, x):
        out = self.linear(x)
        return out
chmod_777 asked Jul 23 '20 15:07


2 Answers

When you write self.linear = nn.Linear(...) inside your custom class, Python actually calls your class's __setattr__ method. When you extend nn.Module, your class inherits a lot of behaviour, and one of the things it inherits is nn.Module's own __setattr__. As you can see in the implementation (I post only the relevant part below), if the value being assigned (here, the nn.Linear instance) is itself an nn.Module, your object must already have an attribute called _modules; otherwise it throws the AttributeError you got:

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # [...]
    # _modules only exists once nn.Module.__init__() has run
    modules = self.__dict__.get('_modules')
    if isinstance(value, Module):
        if modules is None:
            raise AttributeError("cannot assign module before Module.__init__() call")
        # drop any same-named attribute, parameter or buffer, then register the submodule
        remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
        modules[name] = value
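To make the mechanism concrete without needing PyTorch installed, here is a minimal toy sketch that mimics the pattern above. MiniModule, Linear, Good and Bad are hypothetical names for illustration only; this is not PyTorch's actual code, and the error message is simply modeled on PyTorch's. It shows that a plain value such as self.test = 1 is stored normally even without the parent __init__, while assigning a module fails because the _modules registry was never created, which matches the question's observation that the error comes from the self.linear line.

```python
class MiniModule:
    """Hypothetical toy stand-in for nn.Module; not PyTorch's actual code."""

    def __init__(self):
        # Like nn.Module.__init__: create the registry that __setattr__ relies on.
        self._modules = {}

    def __setattr__(self, name, value):
        modules = self.__dict__.get("_modules")
        if isinstance(value, MiniModule):
            if modules is None:
                # Same situation as the question: the parent __init__ never ran.
                raise AttributeError("cannot assign module before Module.__init__() call")
            # Submodules go into the registry, not into the instance __dict__.
            modules[name] = value
        else:
            # Plain values (ints, lists, ...) are stored normally.
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails; fall back to the registry.
        modules = self.__dict__.get("_modules")
        if modules is not None and name in modules:
            return modules[name]
        raise AttributeError(name)


class Linear(MiniModule):
    """Hypothetical stand-in for nn.Linear."""


class Good(MiniModule):
    def __init__(self):
        super().__init__()
        self.test = 1
        self.linear = Linear()


class Bad(MiniModule):
    def __init__(self):
        # super().__init__() deliberately omitted
        self.test = 1           # succeeds: not a module
        self.linear = Linear()  # raises AttributeError


good = Good()
print(list(good._modules))         # ['linear']
print("linear" in good.__dict__)   # False
print(type(good.linear).__name__)  # Linear

try:
    Bad()
except AttributeError as err:
    print(err)  # cannot assign module before Module.__init__() call
```

The registry is the connection the question asks about: nn.Module keeps submodules in _modules rather than in the ordinary instance __dict__, so that it can find them later, and that dictionary only exists after the parent __init__ has run.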

If you look at nn.Module.__init__, you'll see that self._modules is initialized there:

def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()                         # <---- here

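The same is true for parameters and buffers: self._parameters and self._buffers are also created in nn.Module.__init__, so assigning an nn.Parameter or calling register_buffer() before the parent __init__ has run raises a similar AttributeError.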
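Because the registries are created inside the parent __init__, what matters is the order: super().__init__() has to run before any parameter, buffer or submodule is assigned, not merely somewhere in your __init__. The following is a minimal toy sketch of that, assuming hypothetical stand-ins (Param for nn.Parameter, MiniModule for nn.Module, and the ParamTooEarly, BufferTooEarly and Correct classes); it is not PyTorch's code, and the messages are only modeled on PyTorch's.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter."""

    def __init__(self, data):
        self.data = data


class MiniModule:
    """Hypothetical toy stand-in for nn.Module's parameter and buffer registries."""

    def __init__(self):
        # Like nn.Module.__init__: create the registries.
        self._parameters = {}
        self._buffers = {}

    def __setattr__(self, name, value):
        params = self.__dict__.get("_parameters")
        if isinstance(value, Param):
            if params is None:
                raise AttributeError("cannot assign parameters before Module.__init__() call")
            params[name] = value
        else:
            object.__setattr__(self, name, value)

    def register_buffer(self, name, value):
        if "_buffers" not in self.__dict__:
            raise AttributeError("cannot assign buffer before Module.__init__() call")
        self._buffers[name] = value


class ParamTooEarly(MiniModule):
    def __init__(self):
        self.weight = Param([1.0])  # runs before the registries exist
        super().__init__()


class BufferTooEarly(MiniModule):
    def __init__(self):
        self.register_buffer("running_mean", [0.0])  # also too early
        super().__init__()


class Correct(MiniModule):
    def __init__(self):
        super().__init__()  # registries first, then assignments
        self.weight = Param([1.0])
        self.register_buffer("running_mean", [0.0])


for cls in (ParamTooEarly, BufferTooEarly):
    try:
        cls()
    except AttributeError as err:
        print(f"{cls.__name__}: {err}")

ok = Correct()
print(list(ok._parameters), list(ok._buffers))  # ['weight'] ['running_mean']
```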

Berriel answered Oct 13 '22 03:10


You need the super() call so that the nn.Module class itself is initialised. In Python, if a subclass defines its own __init__, the superclass's __init__ is not called automatically; it has to be called explicitly, and that is what super() is for: it works out which class's __init__ to call next.
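A plain-Python sketch of that rule, with no PyTorch involved: Base, NoSuper, WithSuper and NoInit are hypothetical names, and Base.__init__ simply stands in for nn.Module.__init__ setting up internal state. Overriding __init__ without calling super() means the base initialiser never runs, while a subclass that defines no __init__ inherits and runs the base one.

```python
class Base:
    def __init__(self):
        # Stands in for nn.Module.__init__ setting up internal state.
        self.initialised = True


class NoSuper(Base):
    def __init__(self):
        # Overrides Base.__init__ and never calls it.
        self.x = 1


class WithSuper(Base):
    def __init__(self):
        super().__init__()  # explicitly run Base.__init__ first
        self.x = 1


class NoInit(Base):
    pass  # no __init__ of its own, so Base.__init__ is inherited and runs


print(hasattr(NoSuper(), "initialised"))  # False
print(WithSuper().initialised)            # True
print(NoInit().initialised)               # True
```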

I assume you are using Python 3, in which case you don't need the arguments in the super() call; this is sufficient:

        super().__init__()
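For the single-inheritance LR class, super(LR, self).__init__() and super().__init__() both end up calling nn.Module.__init__. The sketch below, using hypothetical classes Base, A, B and C, shows that the two forms are interchangeable and what "works out which class to call" means: super() follows the instance's method resolution order (MRO), so in A it calls B.__init__ rather than Base.__init__ when the object is a C.

```python
class Base:
    def __init__(self):
        self.calls = ["Base"]


class A(Base):
    def __init__(self):
        super().__init__()  # zero-argument form (Python 3)
        self.calls.append("A")


class B(Base):
    def __init__(self):
        super().__init__()
        self.calls.append("B")


class C(A, B):
    def __init__(self):
        super(C, self).__init__()  # explicit form, equivalent here to super().__init__()
        self.calls.append("C")


print([cls.__name__ for cls in C.__mro__])  # ['C', 'A', 'B', 'Base', 'object']
print(C().calls)                            # ['Base', 'B', 'A', 'C']
```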
Tony Suffolk 66 answered Oct 13 '22 01:10