Calling super's forward() method

What is the most appropriate way to call the forward() method of a parent Module? For example, if I subclass the nn.Linear module, I might do the following:

import torch.nn as nn


class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        # Reuse nn.Linear's forward; super's first argument is this subclass
        y = super(LinearWithOtherStuff, self).forward(x)
        z = do_other_stuff(y)  # placeholder for any extra computation
        return z

However, the docs say not to call the forward() method directly:

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

This makes me think that super(LinearWithOtherStuff, self).forward(x) could cause some unexpected errors. Is this true, or am I misunderstanding inheritance?

asked Feb 18 '19 18:02 by dkv


1 Answer

TL;DR

You can use super().forward(...) freely, even when hooks are registered on the module, and even when they are registered through super().

Explanation

As stated in this answer, __call__ exists so that registered hooks (e.g. those added with register_forward_hook) are run.
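To make that mechanism concrete, here is a heavily simplified, pure-Python sketch of how a module's __call__ can wrap forward with forward hooks. TinyModule is a hypothetical stand-in, not PyTorch's actual implementation (the real torch.nn.Module.__call__ also handles pre-hooks, backward hooks, global hooks and more); it only mirrors the documented behaviour that a forward hook returning a non-None value replaces the output.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module (not PyTorch's code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        # Run forward first, then let each hook optionally replace the output
        output = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, (x,), output)
            if result is not None:
                output = result
        return output


class Parent(TinyModule):
    def forward(self, x):
        return x + 1


class Child(Parent):
    def forward(self, x):
        # super().forward is a plain method call: no hooks run here
        return super().forward(x) + 1


module = Child()
module.register_forward_hook(lambda mod, inputs, output: output + 1)
via_call = module(1)             # forward gives 3, the hook adds 1
via_forward = module.forward(1)  # hooks are skipped entirely
print(via_call, via_forward)
```

The point of the sketch: hooks are applied once, around the outermost forward, so whatever forward does internally (including calling super().forward) never triggers them.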

If you inherit from a module and want to reuse the base class's forward, as in this example:

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still

You are perfectly fine as long as you call the module itself (i.e. its __call__ method); calling forward directly won't run the hook (so you get 3, as above).
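Applied to an nn.Linear subclass like the one in the question, a minimal sketch might look as follows. LinearWithReLU is a hypothetical name and torch.relu merely stands in for do_other_stuff. The hook only records each call, to show that it fires exactly once when the layer is called, even though forward internally calls super().forward(...).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearWithReLU(nn.Linear):
    # Hypothetical example: ReLU plays the role of do_other_stuff
    def forward(self, x):
        y = super().forward(x)  # reuse nn.Linear's affine map
        return torch.relu(y)


torch.manual_seed(0)
layer = LinearWithReLU(3, 2)

hook_calls = []
# Returning None from a forward hook leaves the output unchanged
layer.register_forward_hook(lambda mod, inputs, output: hook_calls.append(output.shape))

x = torch.randn(4, 3)
out = layer(x)
expected = torch.relu(F.linear(x, layer.weight, layer.bias))
print(out.shape, len(hook_calls))
```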

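It is unlikely you would want to register a hook through super(), but let's consider such an example (keep in mind that super() is only a proxy for self, so the hook ends up registered on the very same Child instance):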

import torch


def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # super() is a proxy for `self`, so this registers the hook on this same
        # Child instance (and adds a new one every time forward runs)
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Together with the hook added inside forward, the output is incremented twice: 3 + 1 + 1
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here it is 3: hooks are skipped
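One caveat about this pattern: since super() is only a proxy for self, super().register_forward_hook(...) registers the hook on the Child instance itself, and because the call sits inside forward, every forward pass adds one more hook. The sketch below calls the module twice to show the growing output, and also shows one possible alternative (FixedChild is a hypothetical name) that registers the hook once in __init__.

```python
import torch


def increment_by_one(module, inputs, output):
    # A forward hook that returns a value replaces the module's output
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Registers yet another hook on `self` every time forward runs
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


class FixedChild(Parent):
    def __init__(self):
        super().__init__()
        # Register the hook once, at construction time
        self.register_forward_hook(increment_by_one)

    def forward(self, tensor):
        return super().forward(tensor) + 1


module = Child()
module.register_forward_hook(increment_by_one)
first = module(torch.tensor(1))   # 3 from forward, plus 2 hooks
second = module(torch.tensor(1))  # 3 from forward, plus 3 hooks

fixed = FixedChild()
fixed_first = fixed(torch.tensor(1))   # 3 + 1
fixed_second = fixed(torch.tensor(1))  # still 3 + 1
print(first, second, fixed_first, fixed_second)
```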

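So you are perfectly fine using super().forward(...) inside your own forward, and hooks still work correctly: they are run by the outer __call__ on your module, not by forward itself (which is the main idea behind calling the module instead of forward).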
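A related pitfall with the two-argument form of super: its first argument should be the subclass being defined (e.g. super(Child, self) above), not the parent class. super(nn.Linear, self) starts the method lookup after nn.Linear in the MRO, so it reaches torch.nn.Module.forward, which raises NotImplementedError in PyTorch. A small sketch of the difference, with hypothetical class names:

```python
import torch
import torch.nn as nn


class WrongSuper(nn.Linear):
    def forward(self, x):
        # Lookup starts *after* nn.Linear in the MRO, so this resolves to
        # nn.Module.forward rather than nn.Linear.forward
        return super(nn.Linear, self).forward(x)


class RightSuper(nn.Linear):
    def forward(self, x):
        # Zero-argument super() (same as super(RightSuper, self)) finds nn.Linear.forward
        return super().forward(x)


x = torch.randn(2, 3)
try:
    WrongSuper(3, 4)(x)
    wrong_raised = False
except NotImplementedError:
    wrong_raised = True

right_out = RightSuper(3, 4)(x)
print(wrong_raised, right_out.shape)
```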

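BTW, calling super().__call__(...) inside forward would raise a RecursionError: super().__call__ is Module.__call__ bound to the same instance, so it calls self.forward, i.e. your own forward, again and again.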
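A minimal sketch of that failure, using the same Parent as above (Recursive is a hypothetical name): super().__call__ resolves to torch.nn.Module.__call__ bound to the same instance, which dispatches back to Recursive.forward until Python's recursion limit is hit.

```python
import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Recursive(Parent):
    def forward(self, tensor):
        # Module.__call__ on `self` calls self.forward, i.e. this method again
        return super().__call__(tensor) + 1


try:
    Recursive()(torch.tensor(1))
    raised = False
except RecursionError:
    raised = True
print(raised)
```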

answered Sep 30 '22 01:09 by Szymon Maszke