 

When is a PyTorch custom Function needed (rather than only a Module)?

Tags:

pytorch

torch

PyTorch beginner here! Consider the following custom Module:

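import torch
import torch.nn as nn

class Testme(nn.Module):
    def __init__(self):
        super(Testme, self).__init__()

    def forward(self, x):
        # Divide every element of x by the global maximum of x.
        # torch.max(x) returns a 0-dim tensor; expand_as broadcasts it to x's shape.
        return x / torch.max(x).expand_as(x)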

As far as I understand the documentation, this could also be implemented as a custom Function. A subclass of Function requires a backward() method, but a Module does not. Also, in the documentation's example of a Linear Module, the Module depends on a Linear Function:

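class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        ...
    def forward(self, input):
        # In the docs, Linear here refers to the custom autograd Function (old-style call), not to this Module
        return Linear()(input, self.weight, self.bias)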

The question: I do not understand the relation between Module and Function. In the first listing above (module Testme), should it have an associated Function? If not, then it is evidently possible to implement this without a backward method by subclassing Module, so why does a Function always require a backward method?

Perhaps Functions are intended only for operations that are not composed of existing torch functions? Put differently: maybe Modules do not need an associated Function if their forward method is composed entirely of previously defined torch functions?

forgotmysocks asked Jun 08 '17 07:06


1 Answer

This information is gathered and summarised from the official PyTorch documentation.

torch.autograd.Function really lies at the heart of the autograd package in PyTorch. Any graph you build in PyTorch, and any operation you conduct on Variables (plain Tensors since PyTorch 0.4), is based on a Function. A Function has to define a forward() and a backward() method (see more here: http://pytorch.org/docs/notes/extending.html). In the older API it also had an __init__(), whereas current PyTorch expects forward() and backward() as static methods that receive a context object, ctx. This enables PyTorch both to compute results and to compute gradients for Variables.
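A quick way to see this in current PyTorch (a small sketch; the node class names in the comments are what recent versions print and may differ between versions) is to inspect the .grad_fn attribute, the autograd node that every operation on a tensor with requires_grad=True records in the graph:

```python
import torch

# Every operation on a tensor that requires grad records an autograd node
# in the graph; the node that produced a tensor is stored as its .grad_fn.
x = torch.randn(3, requires_grad=True)
y = x * 2
z = y.sum()

print(type(y.grad_fn).__name__)  # e.g. MulBackward0
print(type(z.grad_fn).__name__)  # e.g. SumBackward0

# Each node points back to the node(s) that produced its inputs.
print(z.grad_fn.next_functions)

# backward() walks these nodes in reverse, calling each node's backward.
z.backward()
print(x.grad)  # tensor([2., 2., 2.])
```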

nn.Module, in contrast, is really just a convenience for organising your model, its layers, etc. For example, it collects all the trainable parameters of your model in .parameters() and makes it easy to add another layer to a model. It is not the place where you define a backward method: inside forward() you are supposed to use operations that are themselves Functions, whose backward() is already defined. Hence, once forward() specifies the order of operations, PyTorch already knows how to back-propagate gradients.
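A minimal sketch of such a Module, assuming a hypothetical ScaledMaxNorm layer (the Testme operation with an added learnable scale, invented here for illustration): it defines no backward(), yet gradients reach both the input and the registered parameter, and the parameter shows up in .parameters() for an optimizer.

```python
import torch
import torch.nn as nn


class ScaledMaxNorm(nn.Module):
    """Hypothetical example: scale * x / max(x), with a learnable scale."""

    def __init__(self):
        super().__init__()
        # An nn.Parameter attribute is registered automatically and is
        # returned by .parameters(), which is what torch.optim consumes.
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Only built-in ops: each has its own backward, so no backward()
        # needs to be written for this Module.
        return self.scale * x / torch.max(x)


m = ScaledMaxNorm()
print([name for name, _ in m.named_parameters()])  # ['scale']

x = torch.rand(4, requires_grad=True)
m(x).sum().backward()
print(m.scale.grad, x.grad)
```

The same m.parameters() iterator could be handed directly to an optimizer such as torch.optim.SGD.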

Now, when should you use what?

If you have an operation that is just a composition of functions already implemented in PyTorch (like your Testme above), there's really no point in writing a Function subclass yourself, because you can simply stack the operations and let autograd build the dynamic graph. It is still sensible to bunch these operations together, though. If any operation involves trainable parameters (for example, a linear layer of a neural network), you should subclass nn.Module and bunch your operations together in its forward() method. This gives you easy access to the parameters (as outlined above), e.g. for use with torch.optim. If you don't have any trainable parameters, I would probably still bunch the operations together, but a plain Python function that applies each operation would be sufficient.
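To illustrate why a Function subclass buys nothing for a pure composition, here is a sketch comparing Testme written as a plain function with a hand-written Function (max_normalize and MaxNormalizeFn are hypothetical names; the manual backward assumes the maximum is unique, whereas autograd splits the gradient among ties). Both pass torch.autograd.gradcheck and give the same gradients:

```python
import torch
from torch.autograd import Function, gradcheck


def max_normalize(x):
    """Testme as a plain function: only existing ops, autograd does the rest."""
    return x / torch.max(x)


class MaxNormalizeFn(Function):
    """Hypothetical hand-written equivalent, for comparison only."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x / torch.max(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Recompute the global max m = x[k] and its flat index k (assumed unique).
        m, k = x.reshape(-1).max(dim=0)
        # y_i = x_i / m, so dL/dx_j = g_j / m - [j == k] * sum_i(g_i * x_i) / m^2
        grad_flat = (grad_out / m).reshape(-1).clone()
        grad_flat[k.item()] -= (grad_out * x).sum() / (m * m)
        return grad_flat.reshape(x.shape)


x = torch.tensor([0.2, 0.9, 0.5, 0.1], dtype=torch.double, requires_grad=True)
print(gradcheck(MaxNormalizeFn.apply, (x,)))  # True
print(gradcheck(max_normalize, (x,)))         # True

g = torch.randn(4, dtype=torch.double)
(a,) = torch.autograd.grad(MaxNormalizeFn.apply(x), x, g)
(b,) = torch.autograd.grad(max_normalize(x), x, g)
print(torch.allclose(a, b))                   # True
```

The hand-written version is longer and has an extra edge case (ties) to get right, which is the practical argument for composing existing ops instead.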

If you have a new custom operation (e.g. a new stochastic layer with some complicated sampling procedure), you should subclass Function and define forward() and backward() (plus __init__() in the older API) to tell PyTorch how to compute results and how to compute gradients when you use this operation. Afterwards, you should either create a functional version that calls your Function, or, if your operation has trainable parameters, create a Module that holds them and calls the Function in its forward(). Again, you can read more about this in the link above.
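A sketch of that pattern, adapted from the linear-layer example in the "Extending PyTorch" notes and written in the current static-method API. The names linear and MyLinear and the uniform initialisation are assumptions made here; MyLinear is named differently from the Function to avoid the Linear/Linear name clash in the question.

```python
import torch
import torch.nn as nn
from torch.autograd import Function, gradcheck


class LinearFunction(Function):
    # Current API: static forward/backward that receive a ctx object.
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute the gradients that are actually needed.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # One gradient (or None) per forward() input, in the same order.
        return grad_input, grad_weight, grad_bias


def linear(input, weight, bias=None):
    """Functional wrapper (hypothetical name): no parameters of its own."""
    return LinearFunction.apply(input, weight, bias)


class MyLinear(nn.Module):
    """Module wrapper (hypothetical name): owns the trainable parameters."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.1, 0.1))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features).uniform_(-0.1, 0.1))
        else:
            self.register_parameter("bias", None)

    def forward(self, input):
        return linear(input, self.weight, self.bias)


# Check the hand-written backward against numerical gradients (double precision).
inp = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(2, dtype=torch.double, requires_grad=True)
print(gradcheck(LinearFunction.apply, (inp, w, b)))  # True

layer = MyLinear(3, 2)
out = layer(torch.randn(4, 3))
out.sum().backward()
print(out.shape, layer.weight.grad.shape)  # torch.Size([4, 2]) torch.Size([2, 3])
```

In the current API a Function is invoked through its apply() classmethod rather than instantiated as in Linear()(...); the functional wrapper simply hides that detail.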

mbpaulus answered Oct 15 '22 06:10