Why does super(LR, self).__init__() need to be called in the code below? Without it, I get the error "AttributeError: cannot assign module before Module.__init__() call", raised by the line self.linear = nn.Linear(input_size, output_size).

I don't understand what the connection is between calling super(LR, self).__init__() and being able to assign the nn.Linear object to self.linear. An nn.Linear object is a separate object which can be assigned to a variable outside of any class, so why does super(LR, self).__init__() need to be called before assigning a Linear object to self.linear within the class?
class LR(nn.Module):
    # Constructor
    def __init__(self, input_size, output_size):
        # Inherit from parent
        super(LR, self).__init__()
        self.test = 1
        self.linear = nn.Linear(input_size, output_size)

    # Prediction function
    def forward(self, x):
        out = self.linear(x)
        return out
When you write self.linear = nn.Linear(...) inside your custom class, you are actually calling the __setattr__ method of your class. It just happens that when you extend nn.Module, your class inherits a number of things, and one of them is __setattr__. As you can see in the implementation (only the relevant part is shown below), when the value being assigned is an instance of nn.Module (and an nn.Linear object is one), your class must already have an attribute called _modules; otherwise it will throw the AttributeError you got:
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # [...]
    modules = self.__dict__.get('_modules')
    if isinstance(value, Module):
        if modules is None:
            raise AttributeError("cannot assign module before Module.__init__() call")
        remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
        modules[name] = value
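To see the mechanism in isolation, here is a minimal sketch that imitates this guard with a toy class (the names Container, _children, Broken, and Working are hypothetical, not PyTorch API):

class Container:
    def __init__(self):
        # Create the registry directly in __dict__, bypassing our own __setattr__.
        object.__setattr__(self, '_children', {})

    def __setattr__(self, name, value):
        # Like nn.Module.__setattr__, this only works once __init__ has run.
        children = self.__dict__.get('_children')
        if children is None:
            raise AttributeError("cannot assign child before Container.__init__() call")
        children[name] = value

class Broken(Container):
    def __init__(self):
        # super().__init__() is missing, so self._children was never created.
        self.child = 'x'  # raises the AttributeError above

class Working(Container):
    def __init__(self):
        super().__init__()  # creates self._children first
        self.child = 'x'    # registered in self._children

Working() succeeds and Broken() raises, for exactly the same reason as in your LR class.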
If you take a look at nn.Module's __init__, you'll see that self._modules is initialized there:
def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict()  # <---- here
The same is true for buffers and parameters.
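As a quick sanity check (a sketch assuming PyTorch is installed and the LR class from the question is in scope), you can observe both the success and the failure case:

import torch.nn as nn

model = LR(3, 1)
print(model._modules)  # e.g. OrderedDict([('linear', Linear(in_features=3, out_features=1, bias=True))]); exact repr varies by PyTorch version
print('test' in model._modules)  # False: plain Python values go into __dict__, not _modules
print(model.test)  # 1

class LRNoSuper(nn.Module):
    def __init__(self, input_size, output_size):
        # super().__init__() deliberately omitted
        self.linear = nn.Linear(input_size, output_size)

LRNoSuper(3, 1)  # AttributeError: cannot assign module before Module.__init__() call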
You need the super() call so that the nn.Module class itself is initialised. In Python, superclass constructors/initialisers aren't called automatically; they have to be called explicitly, and that is what super() does: it works out which superclass to call.
I assume you are using Python 3, in which case you don't need the arguments in the super() call; this is sufficient:
super().__init__()
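The rule is easy to demonstrate without PyTorch at all (Base, WithSuper, and WithoutSuper are hypothetical names for illustration):

class Base:
    def __init__(self):
        self.ready = True

class WithSuper(Base):
    def __init__(self):
        super().__init__()  # Base.__init__ runs and sets self.ready

class WithoutSuper(Base):
    def __init__(self):
        pass  # Base.__init__ never runs

print(WithSuper().ready)                 # True
print(hasattr(WithoutSuper(), 'ready'))  # False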