PyTorch network parameter calculation

Can someone please explain how the number of network parameters (10) is calculated? That is the value of len(params), where params = list(net.parameters()). Thanks in advance.

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

Net(
  (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)

Most layer modules in PyTorch (e.g. Linear, Conv2d) group their parameters into named tensors, such as weights and biases. Each of the five layer instances in your network has a "weight" and a "bias" parameter, which gives 5 x 2 = 10 parameter tensors. This is why "10" is printed.
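To see the ten entries by name, you can list what named_parameters() returns. This is only a minimal sketch: LayersOnly is a hypothetical stand-in that registers the same five layers as the Net in the question and leaves out forward(), since only the registered parameters matter here.

```python
import torch.nn as nn

class LayersOnly(nn.Module):
    """Hypothetical stand-in for the question's Net: same five layers, no forward()."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

net = LayersOnly()

# One name per parameter tensor: a weight and a bias for each layer.
param_names = [name for name, _ in net.named_parameters()]
print(param_names)
print(len(param_names))  # same value as len(list(net.parameters()))
```

The names come out in registration order, one "weight" and one "bias" per layer.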

Of course, each of these "weight" and "bias" tensors contains many individual parameters. For example, the weight of your first fully connected layer self.fc1 contains 16 * 5 * 5 * 120 = 48000 parameters, and its bias contains another 120. So len(params) doesn't tell you the number of parameters in the network; it gives you just the total number of "groupings" of parameters in the network.
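The difference between the number of groupings and the number of individual parameters can be checked with plain arithmetic. The sketch below does not use PyTorch at all; the shapes are written out by hand from the layer definitions in the question, assuming the usual layout of (out_channels, in_channels, kernel_h, kernel_w) for Conv2d weights and (out_features, in_features) for Linear weights.

```python
from math import prod

# Shapes written out by hand from the layer definitions (an assumption of this
# sketch, not something read from a live model).
shapes = {
    "conv1.weight": (6, 1, 5, 5),
    "conv1.bias": (6,),
    "conv2.weight": (16, 6, 5, 5),
    "conv2.bias": (16,),
    "fc1.weight": (120, 16 * 5 * 5),
    "fc1.bias": (120,),
    "fc2.weight": (84, 120),
    "fc2.bias": (84,),
    "fc3.weight": (10, 84),
    "fc3.bias": (10,),
}

# Number of scalar parameters in each tensor is the product of its dimensions.
counts = {name: prod(shape) for name, shape in shapes.items()}

print(len(shapes))            # number of groupings: 10
print(counts["fc1.weight"])   # 48000
print(sum(counts.values()))   # 61706 scalar parameters in total
```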

Since Bill already answered why "10" is printed, I am just sharing a code snippet which you can use to find out the number of parameters associated with each layer in your network.

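import numpy

def count_parameters(model):
    """Print the size of each trainable parameter tensor and return the total count."""
    total_param = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            num_param = numpy.prod(param.size())
            if param.dim() > 1:
                # Weights: show the shape as well as the element count.
                print(name, ':', 'x'.join(str(x) for x in list(param.size())), '=', num_param)
            else:
                # Biases are 1-D, so the count alone is enough.
                print(name, ':', num_param)
            total_param += num_param
    return total_param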

Use the above function as follows.

print('number of trainable parameters =', count_parameters(net))


conv1.weight : 6x1x5x5 = 150
conv1.bias : 6
conv2.weight : 16x6x5x5 = 2400
conv2.bias : 16
fc1.weight : 120x400 = 48000
fc1.bias : 120
fc2.weight : 84x120 = 10080
fc2.bias : 84
fc3.weight : 10x84 = 840
fc3.bias : 10
number of trainable parameters = 61706
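If you only need the total and not the per-layer breakdown, a shorter variant is possible with Tensor.numel(), which returns the number of elements in a tensor. This is just a sketch; the helper name count_trainable is made up for illustration, and the check uses a single Linear layer of the same size as fc3 rather than the full network.

```python
import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    """Sum the element counts of all parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Small check on one layer: 10 * 84 weights plus 10 biases.
layer = nn.Linear(84, 10)
print(count_trainable(layer))
```

Calling count_trainable(net) on the network from the question should give the same total as count_parameters(net), without the numpy dependency.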
