What is the most appropriate way to call the `forward()` method of a parent `Module`? For example, if I subclass the `nn.Linear` module, I might do the following:
```python
from torch import nn

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super().forward(x)
        z = do_other_stuff(y)
        return z
```
However, the docs say not to call the `forward()` method directly:

> Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
which makes me think `super().forward(x)` could result in some unexpected errors. Is this true, or am I misunderstanding inheritance?
You can use `super().forward(...)` freely, even with hooks, and even with hooks registered via the `super()` instance.
As stated by this answer, `__call__` exists so that registered hooks (e.g. those added via `register_forward_hook`) will be run.
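To see why, here is a heavily simplified sketch of that dispatch; this is illustrative only, not PyTorch's actual implementation (which also handles pre-hooks, backward hooks, and more):

```python
# Simplified model of why calling the instance runs hooks
# while calling forward() directly does not:
class SimplifiedModule:
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        output = self.forward(*args)        # run the user-defined recipe
        for hook in self._forward_hooks:    # then run the registered hooks
            result = hook(self, args, output)
            if result is not None:
                output = result             # a hook may replace the output
        return output
```

Calling `forward` directly skips the hook loop entirely, which is exactly what the docs warn about.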
If you inherit and want to reuse the base class's `forward`, e.g. like this:
```python
import torch

class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1

class Child(Parent):
    def forward(self, tensor):
        return super().forward(tensor) + 1

module = Child()
# Increment output by 1, so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1)))  # and it is 4 indeed
print(module.forward(torch.tensor(1)))  # here it is 3 still
```
you are perfectly fine if you call the `__call__` method (that is, the module instance itself); calling `forward` directly won't run the hooks (so you get `3`, as above).
It is unlikely you would want to register a hook on the `super()` instance, but let's consider such an example:
```python
import torch

def increment_by_one(module, input, output):
    return output + 1

class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1

class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent; since hooks live on the module object,
        # `super().register_forward_hook` registers on this same instance
        # (and does so again on every forward pass).
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1

module = Child()
# Increment output by 1, so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here it is still 3
```
You are perfectly fine using `super().forward(...)`, and even hooks will work correctly (and that is the main idea of using `__call__` instead of `forward`).
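Applied to the question's original example, a runnable sketch (here `torch.relu` is just a hypothetical stand-in for the undefined `do_other_stuff`):

```python
import torch
from torch import nn

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super().forward(x)  # reuse nn.Linear's affine transform
        return torch.relu(y)    # hypothetical stand-in for do_other_stuff

module = LinearWithOtherStuff(4, 2)
module.register_forward_hook(lambda m, inp, out: print("hook ran:", out.shape))
print(module(torch.randn(3, 4)))  # hook fires; super().forward works as expected
```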
BTW, calling `super().__call__(...)` would cause infinite recursion (ending in a `RecursionError`), because `nn.Module.__call__` dispatches back to `self.forward`, which is the child's own `forward` again.
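A minimal sketch of that failure mode, reusing the `Parent` class from above:

```python
class BrokenChild(Parent):
    def forward(self, tensor):
        # nn.Module.__call__ dispatches to self.forward, i.e. back to this
        # very method, so the call never terminates:
        return super().__call__(tensor)

# BrokenChild()(torch.tensor(1))  # RecursionError: maximum recursion depth exceeded
```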