
How do I pass a keyword argument to the forward used by a pre-forward hook?

Given a torch's nn.Module with a pre-forward hook, e.g.

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings:int, embedding_dim:int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(self.neo_genesis)

    @staticmethod
    def neo_genesis(self, input, higgs_bosson=0):
        if higgs_bosson:
            input = input + higgs_bosson
        return input

This makes it possible to put an input tensor through some manipulation before it reaches the actual forward() function, e.g.

>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]))
tensor([[-1.6449,  0.5832, -0.0165, -1.3329,  0.6878],
        [-0.3262,  0.5844,  0.6917,  0.1268,  2.1363],
        [ 1.0772,  0.1748, -0.7131,  0.7405,  1.5733],
        [ 0.7651,  0.4619,  0.4388, -0.2752, -0.3018]],
       grad_fn=<EmbeddingBackward>)

>>> print(x._forward_pre_hooks)
OrderedDict([(25, <function NeoEmbeddings.neo_genesis at 0x1208d10d0>)])

How could we pass arguments (*args or **kwargs) that the pre-forward hook needs but that are not accepted by the default forward() function?

Without modifying/overriding the forward() function, this is not possible:

>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

----------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-102-8705a40a3cc2> in <module>
      1 x = NeoEmbeddings(10, 5, 1)
----> 2 x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

TypeError: forward() got an unexpected keyword argument 'higgs_bosson'
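
For context, Module.__call__ dispatches to the hooks and then to forward() roughly like this (a simplified sketch, not the exact source of torch/nn/modules/module.py; details vary by version, and older releases ignore the hook's return value entirely). Keyword arguments are handed only to forward() itself; the pre-hooks never see them:

# Simplified sketch of torch.nn.Module.__call__ (details vary by version)
def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        # Hooks receive the positional arguments as a single tuple;
        # kwargs never reach them
        result = hook(self, input)
        if result is not None:
            # A non-tuple return is wrapped and replaces the input tuple
            input = result if isinstance(result, tuple) else (result,)
    # kwargs go straight to forward(), which must accept them
    return self.forward(*input, **kwargs)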
asked Aug 29 '19 by alvas


2 Answers

Torchscript incompatible (as of 1.2.0)

First of all, your example torch.nn.Module has some minor mistakes (probably by accident).

Secondly, you can pass anything to forward, and the hook registered with register_forward_pre_hook will receive whatever positional arguments are passed to your torch.nn.Module (be it a layer, a model, or anything else). You indeed cannot do it without modifying the forward call, but why would you want to avoid that? You can simply forward the arguments to the base function, as seen below:

import torch


class NeoEmbeddings(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    # First argument should be named something like module, as that's the
    # module this hook is registered on
    @staticmethod
    def neo_genesis(module, inputs):  # No need for self as first argument
        net_input, higgs_bosson = inputs  # Simply unpack the tuple here
        # Returning the full tuple keeps both arguments intact on PyTorch
        # versions where the pre-hook's return value replaces the input
        return net_input, higgs_bosson

    def forward(self, inputs, higgs_bosson):
        # Do whatever you want here with both arguments, you can ignore 
        # higgs_bosson if it's only needed in the hook as done here
        return super().forward(inputs)


if __name__ == "__main__":
    x = NeoEmbeddings(10, 5, 1)
    # Call the module itself instead of .forward() so the hooks actually run
    print(x(torch.tensor([0, 2, 5, 8]), 1))

There is no more succinct way to do it; the limitation is the base class's forward method, not the hook itself (and tbh I wouldn't want it to be more succinct, as it would become unreadable IMO).
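
If you actually want the hook to apply higgs_bosson rather than just unpack it, a minimal variant could look like this (a sketch; on PyTorch releases where the pre-hook's return value is ignored, the addition below would have no effect):

    @staticmethod
    def neo_genesis(module, inputs):
        net_input, higgs_bosson = inputs
        if higgs_bosson:
            net_input = net_input + higgs_bosson
        # Returning the full tuple keeps both of forward()'s arguments intact
        return net_input, higgs_bosson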

Torchscript compatible

If you want to use torchscript (tested on 1.2.0), you can use composition instead of inheritance. Only a couple of lines have to change, and your code may look something like this:

import torch

# Inherit from Module and register embedding as submodule
class NeoEmbeddings(torch.nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__()
        # Just use it as a container inside your own class
        self._embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    @staticmethod
    def neo_genesis(module, inputs):
        net_input, higgs_bosson = inputs
        # Return the full tuple so both arguments still reach forward() on
        # versions that use the pre-hook's return value
        return net_input, higgs_bosson

    def forward(self, inputs: torch.Tensor, higgs_bosson: torch.Tensor):
        return self._embedding(inputs)


if __name__ == "__main__":
    x = torch.jit.script(NeoEmbeddings(10, 5, 1))
    # All arguments must be tensors in torchscript
    print(x(torch.tensor([0, 2, 5, 8]), torch.tensor([1])))
answered Sep 24 '22 by Szymon Maszke

Since a forward pre-hook is called with only the module and the input tensor(s) by definition, a keyword argument doesn't make much sense here. What would make more sense is to use an instance attribute, for example:

def neo_genesis(self, input):
    # `input` arrives as the tuple of positional arguments passed to forward()
    if self.higgs_bosson:
        input = (input[0] + self.higgs_bosson,)
    return input
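
For completeness, here is a minimal sketch of how that wires together, assuming the attribute is initialized in __init__ and that your PyTorch version honors the pre-hook's return value (names match the question's class):

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.higgs_bosson = 0  # the hook reads this attribute instead of a kwarg
        self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

    @staticmethod
    def neo_genesis(module, inputs):
        # inputs is the tuple of positional arguments passed to forward()
        if module.higgs_bosson:
            return (inputs[0] + module.higgs_bosson,)
        return inputs

x = NeoEmbeddings(10, 5, 1)
x.higgs_bosson = 1  # switch the attribute instead of passing a kwarg
print(x(torch.tensor([0, 2, 5, 8])))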

Then you can switch that attribute as appropriate. You could also use a context manager for that:

from contextlib import contextmanager

@contextmanager
def HiggsBoson(module):
    module.higgs_bosson = 1
    try:
        yield
    finally:
        module.higgs_bosson = 0  # reset even if forward raises

with HiggsBoson(x):
    x(...)  # call the module itself, not .forward, so the hook actually runs

If you have that function already and you really need to change that parameter, you can still replace the function's __defaults__ attribute:

x.neo_genesis.__defaults__ = (1,)  # this corresponds to the `higgs_bosson` parameter
x(...)  # again, call the module so the hook runs
x.neo_genesis.__defaults__ = (0,)  # reset to default
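
For reference, __defaults__ is just the tuple of default values for a function's optional parameters, so rebinding it changes what the hook sees on its next call. A tiny standalone illustration:

def f(a, b=0):
    return a + b

print(f.__defaults__)   # (0,)
f.__defaults__ = (10,)  # rebind the default for `b`
print(f(1))             # 11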
answered Sep 23 '22 by a_guest