Generally, an `nn.Module` can be subclassed as below.
```python
import torch
import torch.nn as nn

def init_weights(m):
    # Xavier-initialize the weights of every nn.Linear layer
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x
```
My first question is: why can I simply run the code below, even though my `__init__` doesn't take any positional argument for `training_signals`, and it looks like `training_signals` is passed to the `forward()` method? How does that work?
```python
model = LinearRegression()
training_signals = torch.rand(1000, 20)
model(training_signals)
```
My second question is: how does `self.apply(init_weights)` work internally? Is it executed before the `forward` method is called?
Q1: Why can I simply run the code below, even though my `__init__` doesn't take any positional argument for `training_signals`, and it looks like `training_signals` is passed to `forward()`? How does it work?
First, `__init__` is called when you run this line:

```python
model = LinearRegression()
```
As you can see, you pass no arguments, and you shouldn't: `__init__` only builds the layers, and `training_signals` is never passed to it. The signature of your `__init__` matches that of the base class, whose constructor you invoke with `super(LinearRegression, self).__init__()`. As the `nn.Module` source shows, its `__init__` signature is simply `def __init__(self)` (just like yours).
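To make the split between construction and calling concrete, here is a plain-Python sketch (no PyTorch needed). `Base` and `Model` are hypothetical stand-ins for `nn.Module` and `LinearRegression`: the constructor takes no data, so handing it `training_signals`-like data fails with a `TypeError`.

```python
class Base:
    # Hypothetical stand-in for nn.Module: __init__ takes only self
    def __init__(self):
        self.children_names = []

class Model(Base):
    # Hypothetical stand-in for LinearRegression
    def __init__(self):
        super(Model, self).__init__()  # same call pattern as in the question
        self.children_names.append("fc1")

model = Model()  # __init__ runs here; no data is involved
print(model.children_names)  # ['fc1']

constructor_rejects_data = False
try:
    Model([0.1, 0.2, 0.3])  # data does not belong in the constructor
except TypeError:
    constructor_rejects_data = True
print(constructor_rejects_data)  # True
```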
Second, `model` is now an object. When you run the line below:

```python
model(training_signals)
```
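you are actually calling its `__call__` method and passing `training_signals` as a positional argument. As the `nn.Module` source shows, among many other things (for example, running any registered hooks), `__call__` invokes the `forward` method:

```python
result = self.forward(*input, **kwargs)
```

passing all of the positional and keyword arguments it received on to `forward`. (That line comes from an older release; newer versions route `__call__` through an internal `_call_impl` method, but the effect is the same.)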
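A stripped-down, pure-Python sketch of that dispatch follows. `ToyModule` and `Scale` are hypothetical; `ToyModule` keeps only the forwarding step and omits everything else the real `nn.Module.__call__` does (hooks, tracing, and so on):

```python
class ToyModule:
    """Hypothetical, heavily simplified stand-in for nn.Module."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks; here we only forward.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward")

class Scale(ToyModule):
    def __init__(self, factor=2):
        # The constructor receives configuration, not data
        self.factor = factor

    def forward(self, x):
        # Receives whatever was passed to model(...)
        return [self.factor * v for v in x]

model = Scale()         # calls __init__
out = model([1, 2, 3])  # calls __call__, which calls forward
print(out)              # [2, 4, 6]
```

Because `__call__` is defined on the base class, every subclass only needs to implement `forward`, and calling the object routes the data there.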
Q2: How does `self.apply(init_weights)` work internally? Is it executed before the `forward` method is called?
PyTorch is open source, so you can simply check the source code. The implementation of `nn.Module.apply` is quite simple:
```python
def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
```
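The recursion can be mimicked outside PyTorch to see the visiting order. `Node` below is a hypothetical toy class whose `apply` follows the same three steps: recurse into each child, call `fn` on itself, then return `self`. The result is a post-order traversal, so children are visited before their parent:

```python
class Node:
    """Hypothetical toy module tree; not part of PyTorch."""

    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def apply(self, fn):
        # Same structure as nn.Module.apply: children first, then self
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

tree = Node("model", [Node("block", [Node("fc1"), Node("fc2")]), Node("head")])
visited = []
tree.apply(lambda m: visited.append(m.name))
print(visited)  # ['fc1', 'fc2', 'block', 'head', 'model']
```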
Quoting the documentation of the function, it "applies `fn` recursively to every submodule (as returned by `.children()`) as well as self". Based on the implementation, you can also infer the requirements:
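- `fn` must be a callable;
- `fn` receives as input only a `Module` object.

As for when it runs: `self.apply(init_weights)` is an ordinary method call inside your `__init__`, so it executes once, when `LinearRegression()` is constructed, and therefore before any call to `forward`. It visits the children first (here `fc1`) and the module itself last.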
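Here is a toy sketch of the timing question, again in pure Python. `Module`, `Linear`, `LinearRegression` and `init_weights` below are hypothetical classes that only mimic the names in the question (the toy `children()` simply scans instance attributes, unlike the real `nn.Module`). An event log shows that the initializer runs during construction and `forward` runs only later, when the model is called; the `isinstance` check mirrors the filter in `init_weights`, so only the linear layer is touched:

```python
events = []

class Module:
    """Hypothetical minimal module with children(), apply() and __call__."""

    def children(self):
        # Toy version: any attribute that is itself a Module counts as a child
        return [v for v in vars(self).values() if isinstance(v, Module)]

    def apply(self, fn):
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

class Linear(Module):
    def __init__(self):
        self.weight = None

def init_weights(m):
    # Only Linear layers are initialized, as in the question
    if isinstance(m, Linear):
        m.weight = "initialized"
        events.append("init fc1")

class LinearRegression(Module):
    def __init__(self):
        self.fc1 = Linear()
        self.apply(init_weights)  # runs during construction

    def forward(self, x):
        events.append("forward")
        return x

model = LinearRegression()
model([1.0, 2.0])
print(events)  # ['init fc1', 'forward']
```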