
Why do we pass nn.Module as an argument to class definition for neural nets?

I want to understand why we pass torch.nn.Module as an argument when we define the class for a neural network, such as the generator of a GAN:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f
Franklin Varghese asked Jun 01 '19 09:06



2 Answers

This line

class Generator(nn.Module):

simply means that the Generator class inherits from the nn.Module class. nn.Module is a base class here, not an argument.

However, the dunder init method:

def __init__(self, input_size, hidden_size, output_size, f):

has self as its first parameter, which may be why you think of this as passing an argument.

self is simply the class instance in Python. There have been debates over whether the explicit self should stay or go, but Guido van Rossum explained in his blog why it has to stay.
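To make the difference concrete, here is a minimal sketch in plain Python. Base is a hypothetical stand-in for a framework class such as nn.Module, and describe is an invented method used only for illustration. It shows that the name in parentheses after the class name is recorded as a parent class, while the real arguments are the ones listed in __init__.

```python
class Base:
    """Hypothetical stand-in for a framework base class such as nn.Module."""

    def describe(self):
        # Inherited by every subclass; works on whatever instance calls it.
        return f"{type(self).__name__} with {self.hidden_size} hidden units"


class Generator(Base):  # Base is a parent class here, not an argument
    def __init__(self, hidden_size):  # these are the actual arguments
        self.hidden_size = hidden_size  # self is the instance being built


g = Generator(hidden_size=16)  # only hidden_size is passed; self is implicit
print(Generator.__bases__)     # the parent classes recorded on the class
print(isinstance(g, Base))     # a Generator instance is also a Base
print(g.describe())            # method inherited from Base
```

Nothing is "passed" to Base when the class statement runs. Python only notes that Generator should look up missing attributes and methods on Base.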

prosti answered Sep 29 '22 14:09


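We are essentially defining the class Generator as a subclass of nn.Module, so it gets all of nn.Module's functionality. In programming we refer to this as inheritance. The call super(Generator, self).__init__() runs nn.Module's own initializer, so that the inherited functionality is set up before the layers are assigned.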

You can read more here: https://www.w3schools.com/python/python_inheritance.asp
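As a rough illustration of why the super().__init__() call matters, here is a toy sketch in plain Python. The Module and Layer classes below are simplified, hypothetical stand-ins written for this example. They are not PyTorch's implementation, and parameter_count is an invented helper. The idea they imitate is that the base class keeps a registry of the layers assigned as attributes, and that registry only exists after the base initializer has run.

```python
class Layer:
    """Toy stand-in for a layer such as nn.Linear (hypothetical, weights only)."""

    def __init__(self, n_in, n_out):
        self.weights = [[0.0] * n_in for _ in range(n_out)]


class Module:
    """Toy stand-in for nn.Module; NOT PyTorch's actual implementation."""

    def __init__(self):
        # Create the registry directly, bypassing our own __setattr__.
        object.__setattr__(self, "_layers", {})

    def __setattr__(self, name, value):
        if isinstance(value, Layer):
            if "_layers" not in self.__dict__:
                raise AttributeError("call super().__init__() before assigning layers")
            self._layers[name] = value  # remember the layer by attribute name
        object.__setattr__(self, name, value)

    def parameter_count(self):
        # Count every weight in every registered layer.
        return sum(len(row) for layer in self._layers.values() for row in layer.weights)


class Generator(Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # sets up the registry inherited from Module
        self.map1 = Layer(input_size, hidden_size)
        self.map2 = Layer(hidden_size, hidden_size)
        self.map3 = Layer(hidden_size, output_size)


class Broken(Module):
    def __init__(self):
        self.map1 = Layer(2, 2)  # no super().__init__() call first


g = Generator(2, 4, 1)
print(sorted(g._layers))    # names of the layers the base class tracked
print(g.parameter_count())  # 2*4 + 4*4 + 4*1 weights

try:
    Broken()
except AttributeError as err:
    print("error:", err)
```

Generator never defines parameter_count or the registry itself. Both come from the parent class, which is the practical benefit of inheriting from it.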

Nicolas Essipova answered Sep 29 '22 13:09