PyTorch beginner here! Consider the following custom Module:
class Testme(nn.Module):
    def __init__(self):
        super(Testme, self).__init__()

    def forward(self, x):
        return x / t_.max(x).expand_as(x)
As far as I understand the documentation, I believe this could also be implemented as a custom Function. A subclass of Function requires a backward() method, but a Module does not. As well, in the doc example of a Linear Module, it depends on a Linear Function:
class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        ...

    def forward(self, input):
        return Linear()(input, self.weight, self.bias)
The question: I do not understand the relation between Module and Function. In the first listing above (the Testme module), should it have an associated Function? If not, then it is possible to implement this without a backward method by subclassing Module, so why does a Function always require a backward method?

Perhaps Functions are intended only for functions that are not composed out of existing torch functions? Put differently: maybe modules do not need an associated Function if their forward method is composed entirely from previously defined torch functions?
PyTorch uses modules to represent neural networks. Modules are the building blocks of stateful computation. PyTorch provides a robust library of modules and makes it simple to define new custom modules, allowing for easy construction of elaborate, multi-layer neural networks.
A PyTorch Parameter is a wrapper around a tensor (historically, a Variable) intended to be used inside an nn layer or module. A Parameter that is assigned as an attribute of a custom module is registered as a model parameter and is therefore returned by model.parameters().
ctx is a context object that can be used to stash information for the backward computation. You can cache arbitrary objects for use in the backward pass using the ctx.save_for_backward method.
This information is gathered and summarised from the official PyTorch Documentation.
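For instance, a minimal sketch of that registration behaviour (the class and attribute names here are purely illustrative):

import torch
import torch.nn as nn

class MyLayer(nn.Module):
    def __init__(self):
        super(MyLayer, self).__init__()
        # An nn.Parameter assigned as an attribute is registered automatically ...
        self.scale = nn.Parameter(torch.ones(3))
        # ... while a plain tensor attribute is not.
        self.offset = torch.zeros(3)

    def forward(self, x):
        return x * self.scale + self.offset

layer = MyLayer()
print([name for name, _ in layer.named_parameters()])  # ['scale']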
torch.autograd.Function really lies at the heart of the autograd package in PyTorch. Any graph you build in PyTorch, and any operation you conduct on Variables in PyTorch, is based on a Function. Any Function requires an __init__(), forward() and backward() method (see more here: http://pytorch.org/docs/notes/extending.html). This enables PyTorch to compute results and to compute gradients for Variables.
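A minimal sketch of such a Function (note: this uses the staticmethod/ctx style described above, which newer PyTorch versions use in place of __init__; the MulConstant name is just for illustration):

import torch
from torch.autograd import Function

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # Stash whatever backward() will need on the ctx object.
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward() input; None for the plain number.
        return grad_output * ctx.constant, None

x = torch.randn(3, requires_grad=True)
y = MulConstant.apply(x, 2.0)
y.sum().backward()  # x.grad is now a tensor filled with 2.0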
nn.Module() in contrast is really just a convenience for organising your model, your different layers, and so on. For example, it organises all the trainable parameters in your model in .parameters() and allows you to add another layer to a model easily. It is not the place where you define a backward method, because in the forward() method you're supposed to use subclasses of Function(), for which you have already defined backward(). Hence, if you have specified the order of operations in forward(), PyTorch already knows how to back-propagate gradients.
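To make that concrete for the Testme module from the question (a sketch, assuming t_ stands for the torch package): no backward is defined anywhere, yet gradients flow through it, because every operation in forward() is already known to autograd.

import torch
import torch.nn as nn

class Testme(nn.Module):
    def forward(self, x):
        # Only existing, differentiable torch ops are used here,
        # so autograd derives the backward pass automatically.
        return x / torch.max(x).expand_as(x)

x = torch.randn(4, requires_grad=True)
out = Testme()(x)
out.sum().backward()  # works without any backward() method
print(x.grad)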
Now, when should you use what?
If you have an operation that is just a composition of existing implemented functions in PyTorch (like your example above), there's really no point in subclassing Function() yourself, because you can just stack operations up and build a dynamic graph. It is, however, a sensible idea to bunch these operations together. If any operation involves trainable parameters (for example a linear layer of a neural network), you should subclass nn.Module() and bunch your operations together in the forward method. This allows you to easily access the parameters (as outlined above), for use with torch.optim, etc. If you don't have any trainable parameters, I would probably still bunch the operations together, but a standard Python function, where you take care of the instantiation of each operation you use, would be sufficient.
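For the no-parameter case, that plain-function version could look like this (a sketch; the function name is just illustrative):

import torch

def normalize_by_max(x):
    # A plain Python function built from existing torch ops;
    # autograd still tracks everything done inside it.
    return x / torch.max(x).expand_as(x)

y = normalize_by_max(torch.randn(4, requires_grad=True))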
If you have a new custom operation (e.g. a new stochastic layer with some complicated sampling procedure), you should subclass Function() and define __init__(), forward() and backward() to tell PyTorch how to compute results and how to compute gradients when you use this operation. Afterwards, you should either create a functional version that takes care of instantiating the Function and uses your operation, or create a module, if your operation has trainable parameters. Again, you can read more about this in the link above.