
Why aren't torch.nn.Parameter listed when net is printed?

Tags:

python

pytorch

I recently had to construct a module that includes an extra trainable tensor. Backpropagation worked perfectly when I used torch.nn.Parameter, but the parameter did not show up when printing the net object. Why isn't it included in the output, in contrast to submodules like layer? (Shouldn't it behave just like layer?)

import torch
import torch.nn as nn

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.layer = nn.Linear(10, 10)
        self.parameter = torch.nn.Parameter(torch.zeros(10,10, requires_grad=True))

net = MyNet()
print(net)

Output:

MyNet(
  (layer): Linear(in_features=10, out_features=10, bias=True)
)
flawr, asked Feb 19 '19

People also ask

What is Torch nn parameter?

torch.nn.parameter.Parameter(data=None, requires_grad=True) is a kind of Tensor that is to be considered a module parameter.

What is model parameters () PyTorch?

A Parameter is a Tensor subclass with one special behaviour: when it is assigned as an attribute of a custom nn.Module, it is automatically registered as a model parameter and is therefore returned by model.parameters().

How does torch nn linear work?

nn.Linear(n, m) is a module that creates a single-layer feed-forward network with n inputs and m outputs. Mathematically, it computes the affine transformation y = xAᵀ + b, where x is the input, A is the weight matrix and b is the bias.
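A quick sketch checking that formula against the module's own output (the dimensions here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 2)        # 3 inputs, 2 outputs
x = torch.randn(5, 3)        # batch of 5 samples

y = lin(x)                             # what the module computes
manual = x @ lin.weight.T + lin.bias   # y = x A^T + b, done by hand

print(torch.allclose(y, manual))  # True
```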

What is register buffer in PyTorch?

PyTorch allows subclasses of nn.Module to register a buffer on an object using self.register_buffer("foo", initial_value). A buffer is treated as a plain Tensor attribute of the module: it is saved in the state_dict, but it is not returned by parameters() and receives no gradients.
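A minimal sketch of the buffer/parameter distinction (the buffer name "running_mean" is just an illustrative choice):

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered as a buffer, not a parameter
        self.register_buffer("running_mean", torch.zeros(3))

m = WithBuffer()
# Buffers are not trainable parameters...
print(list(m.parameters()))              # []
# ...but they are part of the module's persistent state.
print("running_mean" in m.state_dict())  # True
```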


2 Answers

When you call print(net), the __repr__ method is called. __repr__ gives the “official” string representation of an object.

In PyTorch's nn.Module (base class of your MyNet model), the __repr__ is implemented like this:

def __repr__(self):
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str

Note that the method above builds main_str only from self._modules and extra_repr(), so by default it prints only submodules (plus whatever extra_repr() returns) -- parameters are never consulted.


PyTorch also provides an extra_repr() method, which you can implement yourself to add extra information to the module's representation. From the docs:

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.
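For example, overriding extra_repr() in the asker's MyNet makes the parameter visible in print(net) (a sketch; the exact string returned is up to you):

```python
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.layer = nn.Linear(10, 10)
        self.parameter = nn.Parameter(torch.zeros(10, 10))

    def extra_repr(self):
        # Whatever this returns is included in repr(self), one item per line
        return '(parameter): Parameter of shape {}'.format(tuple(self.parameter.shape))

net = MyNet()
print(net)
# MyNet(
#   (parameter): Parameter of shape (10, 10)
#   (layer): Linear(in_features=10, out_features=10, bias=True)
# )
```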

kHarshit, answered Sep 27 '22

According to the nn.Parameter docs:

Parameters are Tensor subclasses, that have a very special property when used with Modules - when they're assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in the parameters() iterator.

So you can find it via net.parameters(). Let's look at the following example:

Code:

import torch
import torch.nn as nn

torch.manual_seed(42)

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.layer = nn.Linear(4, 4)
        self.parameter = nn.Parameter(torch.zeros(4, 4, requires_grad=True))
        self.tensor = torch.ones(4, 4)
        self.module = nn.Module()

net = MyNet()
print(net)

Output:

MyNet(
  (layer): Linear(in_features=4, out_features=4, bias=True)
  (module): Module()
)

As you can see, neither tensor nor parameter shows up (parameter does not either, since Parameter is a subclass of Tensor); only modules are listed.

Now let's try to get our net parameters:

Code:

for p in net.parameters():
    print(p)

Output:

Parameter containing:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], requires_grad=True)
Parameter containing:
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936],
        [ 0.4408, -0.3668,  0.4346,  0.0936],
        [ 0.3694,  0.0677,  0.2411, -0.0706]], requires_grad=True)
Parameter containing:
tensor([ 0.3854,  0.0739, -0.2334,  0.1274], requires_grad=True)

Ok, so the first one is your net.parameter. The next two are the weight and bias of net.layer. Let's verify it:

Code:

print(net.layer.weight)
print(net.layer.bias)

Output:

Parameter containing:
tensor([[ 0.3823,  0.4150, -0.1171,  0.4593],
        [-0.1096,  0.1009, -0.2434,  0.2936],
        [ 0.4408, -0.3668,  0.4346,  0.0936],
        [ 0.3694,  0.0677,  0.2411, -0.0706]], requires_grad=True)
Parameter containing:
tensor([ 0.3854,  0.0739, -0.2334,  0.1274], requires_grad=True)
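To match each parameter to its attribute without comparing printed values by eye, named_parameters() yields every parameter together with its qualified name (using a stripped-down MyNet for brevity):

```python
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.layer = nn.Linear(4, 4)
        self.parameter = nn.Parameter(torch.zeros(4, 4))

net = MyNet()
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
# parameter (4, 4)
# layer.weight (4, 4)
# layer.bias (4,)
```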
trsvchn, answered Sep 27 '22