
How does PyTorch's nn.Module register submodules?

Tags: python, pytorch

While reading the Python source of torch.nn.Module, I found that the attribute self._modules is used in many methods, such as self.modules() and self.children(). However, I could not find any function that updates it. Where is self._modules updated, and how does PyTorch's nn.Module register submodules?

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def named_modules(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m
JK.song asked Mar 05 '19 02:03



1 Answer

Adding some details to Jiren Jin's answer:

  • The layers of a network (each of which subclasses nn.Module) are stored in Module._modules, which is initialized in __construct, called from __init__:

    def __init__(self):
        self.__construct()
        # initialize self.training separately from the rest of the internal
        # state, as it is managed differently by nn.Module and ScriptModule
        self.training = True
    
    def __construct(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        # ...
        self._modules = OrderedDict()
    
  • self._modules is updated in __setattr__. Python calls obj.__setattr__(name, value) whenever obj.name = value is executed. For example, if one writes self.conv1 = nn.Conv2d(128, 256, 3, 1, 1) while initializing a network that subclasses nn.Module, the following code from nn.Module.__setattr__ runs:

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            ...  # parameter registration elided
        elif params is not None and name in params:
            ...  # elided
        else:
            # read _modules straight from __dict__; unlike self._modules, this
            # returns None rather than raising if Module.__init__() has not run yet
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                # make sure the name does not also live in __dict__,
                # _parameters, or _buffers
                remove_from(self.__dict__, self._parameters, self._buffers)
                # register the given layer (nn.Conv2d) under its name (conv1);
                # equivalent to self._modules['conv1'] = nn.Conv2d(128, 256, 3, 1, 1)
                modules[name] = value
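
To see the registration mechanism in isolation, here is a minimal pure-Python sketch. TinyModule, Conv, and Net are hypothetical names invented for illustration. This is not PyTorch's code, only a simplified imitation of the __setattr__ logic quoted above, with parameters, buffers, and hooks left out.

```python
from collections import OrderedDict


class TinyModule:
    """Hypothetical, stripped-down stand-in for nn.Module (not PyTorch code)."""

    def __init__(self):
        # Write through __dict__ so our own __setattr__ hook is bypassed
        # while the registry itself is being created.
        self.__dict__['_modules'] = OrderedDict()

    def __setattr__(self, name, value):
        modules = self.__dict__.get('_modules')
        if isinstance(value, TinyModule):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before TinyModule.__init__() call")
            # Keep the name out of the instance dict so the registry is
            # the only place where the submodule is stored.
            self.__dict__.pop(name, None)
            modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Python calls this only when normal attribute lookup fails,
        # which is the case for submodules kept in _modules.
        modules = self.__dict__.get('_modules', {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def children(self):
        return iter(self._modules.values())


class Conv(TinyModule):
    pass


class Net(TinyModule):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv()  # routed through __setattr__ into _modules['conv1']
        self.depth = 1       # plain attribute, stored in __dict__ as usual


net = Net()
registered = list(net._modules.keys())
```

In this sketch, conv1 ends up in net._modules rather than in net.__dict__, and net.conv1 still works because __getattr__ falls back to the registry. This is also why no explicit "register" call shows up when searching the class for writes to _modules: the plain assignment is the registration.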
    

Question from comments:

Do you know how this works with the fact that torch lets you supply your own forward method?

When one runs a forward pass of a network that subclasses nn.Module, nn.Module.__call__ is invoked, and it in turn calls self.forward. Because the subclass overrides forward, self.forward resolves to the user's implementation through ordinary Python method lookup.
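
The dispatch can be sketched in a few lines of plain Python. TinyModule and Double are hypothetical names used only for illustration. The real nn.Module.__call__ also runs hooks and other bookkeeping, which this sketch leaves out.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module; hooks and other details omitted."""

    def __call__(self, *args, **kwargs):
        # self.forward is looked up on the instance's actual class first,
        # so a subclass override is the one that gets called.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Placeholder in the base class: subclasses are expected to override it.
        raise NotImplementedError


class Double(TinyModule):
    def forward(self, x):
        return 2 * x


layer = Double()
result = layer(21)  # layer(...) -> TinyModule.__call__ -> Double.forward
```

Calling the instance rather than forward directly is what lets the base class wrap extra behavior around the user's method.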

Yuen Tau answered Oct 04 '22 16:10
