
How to wrap PyTorch functions and implement autograd?

I'm working through the PyTorch tutorial on Defining new autograd functions. The autograd function I want to implement is a wrapper around torch.nn.functional.max_pool1d. Here is what I have so far:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):

    @staticmethod
    def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, \
                return_indices=False, ceil_mode=False):
        ctx.save_for_backward( input )

        inputC = input.clone() #copy input
        inputC *= inputC

        output = F.max_pool1d(inputC, kernel_size, stride=stride, \
                              padding=padding, dilation=dilation, \
                              return_indices=return_indices, \
                              ceil_mode=ceil_mode)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = get_max_pool1d_grad_somehow(grad_output)
        return 2.0*input*grad_input

My question is: how do I get the gradient of the wrapped function? I know there are probably other ways to do this, given how simple the example is, but what I want to do fits this framework and requires me to implement an autograd function.

Edit: After examining this blog post I decided to try the following for backward:

def backward(ctx, grad_output):
    input, output = ctx.saved_tensors
    grad_input = output.backward(grad_output)
    return 2.0*input*grad_input

with output added to the saved variables. I then run the following code:

x = np.random.randn(1,1,5)
xT = torch.from_numpy(x)
xT.requires_grad=True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT,2))
s.backward()

and I get Bus error: 10.

Say, xT is tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64), then I would expect to find that xT.grad is tensor([[[ 3.39067124, -0. , 9.14775812, -0. , -2.02066994]]], dtype=torch.float64) after calling s.backward() (that is 2*x*grad_of_max_pool, with grad_of_max_pool containing tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64)).

I've figured out why I get a Bus error: 10. It appears that the above code leads to a recursive call of my backward at grad_input = output.backward(grad_output). So I need to find some other way to get the gradient of max_pool1d. I know how to implement this in pure Python, but the result would be much slower than if I could wrap the library code.

Sean Lake asked Dec 05 '22 10:12


2 Answers

You have picked a rather unlucky example. torch.nn.functional.max_pool1d is not an instance of torch.autograd.Function, because it's a PyTorch built-in, defined in C++ code and with an autogenerated Python binding. I am not sure if it's possible to get the backward property via its interface.

First, in case you haven't noticed, you don't need any custom backpropagation code for this formula: both the power operation and max_pool1d already define it, so autograd covers their composition as well. Assuming your goal is an exercise, I would suggest doing it more manually (without falling back to the backward of max_pool1d). An example is below:

import torch
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, **kwargs):
        # we're gonna need indices for backward. Currently SquareAnd...
        # never actually returns indices, I left it out for simplicity
        kwargs['return_indices'] = True

        input_sqr = input ** 2
        output, indices = F.max_pool1d(input_sqr, kernel_size, **kwargs)
        ctx.save_for_backward(input, indices)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, indices = ctx.saved_tensors

        # first we need to reconstruct the gradient of `max_pool1d`
        # by putting all the output gradient elements (corresponding to
        # input elements which made it through the max_pool1d) in their
        # respective places; the rest has a gradient of 0. We do it by
        # scattering into a tensor of 0s. scatter_add_ (rather than
        # scatter_) accumulates when overlapping windows pick the same
        # input element
        grad_output_unpooled = torch.zeros_like(input)
        grad_output_unpooled.scatter_add_(2, indices, grad_output)

        # then incorporate the gradient of the "square" part of your
        # operator
        grad_input = 2. * input * grad_output_unpooled

        # backward must return one value per argument of forward(), see
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
        # Non-tensor arguments such as kernel_size get None, so we
        # return a (grad_input, None) tuple to avoid a complaint that
        # two outputs were expected
        return grad_input, None

We can then use the numerical gradient checker to verify that the operation works as expected.

f = SquareAndMaxPool1d.apply
xT = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)
tag.gradcheck(lambda t: f(t, 2), xT)
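
As a cross-check of the point that autograd already covers the composition, the values quoted in the question can be run through plain autograd. This is a minimal sketch: the helper name square_and_pool_grad is hypothetical, and the stride=1 call is my assumption about what the question intended, since the pooling gradient [1, 0, 2, 0, 1] quoted there only arises with overlapping windows. With the default stride (equal to kernel_size) the windows do not overlap and the last element is dropped.

```python
import torch
import torch.nn.functional as F

# Values quoted in the question.
x = torch.tensor([[[1.69533562, -0.21779421, 2.28693953,
                    -0.86688095, -1.01033497]]],
                 dtype=torch.float64, requires_grad=True)

def square_and_pool_grad(x, kernel_size, stride=None):
    """Gradient of sum(max_pool1d(x**2)) w.r.t. x, using plain autograd."""
    out = F.max_pool1d(x ** 2, kernel_size, stride=stride)
    grad, = torch.autograd.grad(out.sum(), x)
    return grad

# stride=None means stride == kernel_size: windows (0,1) and (2,3) only.
grad_default = square_and_pool_grad(x, 2)
# stride=1 (assumption): windows (0,1), (1,2), (2,3), (3,4) overlap,
# so element 2 is selected twice and its gradient is counted twice.
grad_stride1 = square_and_pool_grad(x, 2, 1)

print(grad_default)
print(grad_stride1)
```

The stride=1 result matches the gradient the question expects, while the default-stride result differs at positions 2 and 4.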

I'm sorry if this doesn't address your question of how to get the backward of max_pool1d, but hopefully you find my answer useful enough.

Jatentaki answered Jan 31 '23 04:01


The recursive calls you ran into come from output: forward and backward of a class derived from torch.autograd.Function appear to run under no_grad by default. If you check output.grad_fn in forward, it will probably be None, and in backward it will probably point to the function object <SquareAndMaxPool1d...>, which causes the recursive calls. If you are still interested in how to do exactly what you asked, here is an example with F.linear:

import torch
import torch.nn as nn  # needed for nn.Linear below
import torch.nn.functional as F

class custom_Linear(nn.Linear):
    def forward(self, _input):
        return Custom_Linear_AGfn_getAround.apply(_input, self.weight, self.bias)

class Custom_Linear_AGfn_getAround(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, _weight, _bias):
        print('Custom forward')
        with torch.enable_grad():
            detached_input = _input.detach()
            detached_input.requires_grad_(True)
            detached_weight = _weight.detach()
            detached_weight.requires_grad_(True)
            detached_bias = _bias.detach()
            detached_bias.requires_grad_(True)
            _tmp = F.linear(detached_input, detached_weight, detached_bias)
        ctx.saved_input = detached_input
        ctx.saved_param = detached_weight, detached_bias
        ctx.save_for_backward(_tmp)
        _output = _tmp.detach()
        return _output

    @staticmethod
    def backward(ctx, grad_out):
        print('Custom backward')
        _tmp, = ctx.saved_tensors
        _weight, _bias = ctx.saved_param
        detached_input = ctx.saved_input
        with torch.enable_grad():
            _tmp.backward(grad_out)
        return detached_input.grad, _weight.grad, _bias.grad

Basically, it is just a matter of constructing a small isolated graph for the part of interest without interfering with the main graph, and using grad_fn and requires_grad to keep track of the graphs when deciding what to detach and what the isolated graph needs.
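
One way to see which graph a tensor belongs to is to inspect grad_fn inside and outside forward. The sketch below is only an illustration: Probe and its seen dictionary are hypothetical names, and the operation (multiplying by 2) is chosen for brevity.

```python
import torch

class Probe(torch.autograd.Function):
    # Hypothetical helper: records what forward observes about grad mode.
    seen = {}

    @staticmethod
    def forward(ctx, x):
        y = x * 2
        # Inside forward, grad mode is off, so y is not tracked.
        Probe.seen['grad_enabled'] = torch.is_grad_enabled()
        Probe.seen['y_grad_fn_is_none'] = y.grad_fn is None
        # Re-enabling grad on a detached copy builds an isolated graph.
        with torch.enable_grad():
            z = x.detach().requires_grad_(True) * 2
        Probe.seen['z_has_grad_fn'] = z.grad_fn is not None
        return y

    @staticmethod
    def backward(ctx, grad_out):
        # d(2x)/dx = 2
        return grad_out * 2

x = torch.ones(3, requires_grad=True)
out = Probe.apply(x)
print(Probe.seen)
print(out.grad_fn)  # set by autograd once forward has returned
```

The returned tensor only acquires its grad_fn after forward finishes, which is why calling backward on a saved output from inside the custom backward re-enters the same function.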

About the tricky parts:

  • Detaching the weight and bias: you could go without, but then EITHER you pass _weight and _bias through save_for_backward, in which case _weight.grad and _bias.grad are None inside backward but have their correct values once outside, OR you pass them through an attribute such as ctx.saved_param, in which case you have to return None for the last two values of backward (return detached_input.grad, None, None). Otherwise the weight and bias gradients come out at twice the correct value when you check them outside of backward afterwards.
  • As said at the beginning, forward and backward of a class derived from torch.autograd.Function seem to have no_grad behavior by default. Removing with torch.enable_grad(): from the code above therefore leaves _tmp.grad_fn as None. (I could not understand why _tmp had grad_fn set to None and requires_grad set to False in forward, despite requiring the gradient for detached_input, until I bumped into https://github.com/pytorch/pytorch/issues/7698.)
  • I believe, but did not check, that you might get a double grad_fn for _output if you do not detach it. When I omit with torch.enable_grad() and do not detach the output, so that _tmp.grad_fn is None in forward, it acquires the <Custom_Linear_AGfn_getAround...> grad_fn in backward, which results in the infinite recursive calls.
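
The same isolated-graph idea can be applied to the max_pool1d case from the question. The following is a sketch under assumptions, not a definitive implementation: the class name WrappedSquareAndMaxPool1d is hypothetical, only kernel_size and stride are forwarded, the intermediate tensors are stored as plain attributes on ctx, and torch.autograd.grad is used in place of _tmp.backward so that nothing is accumulated into .grad fields. Double backward is not supported.

```python
import torch
import torch.nn.functional as F

class WrappedSquareAndMaxPool1d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        # Build a small isolated graph on a detached copy of the input.
        with torch.enable_grad():
            detached = input.detach().requires_grad_(True)
            pooled = F.max_pool1d(detached ** 2, kernel_size, stride=stride)
        ctx.detached = detached
        ctx.pooled = pooled
        # Return a detached result so the outer graph only sees this Function.
        return pooled.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # Let the library differentiate the isolated graph for us.
        grad_input, = torch.autograd.grad(ctx.pooled, ctx.detached, grad_output)
        # One return value per forward argument; non-tensors get None.
        return grad_input, None, None

# Values quoted in the question; stride=1 is an assumption.
xT = torch.tensor([[[1.69533562, -0.21779421, 2.28693953,
                     -0.86688095, -1.01033497]]],
                  dtype=torch.float64, requires_grad=True)
s = WrappedSquareAndMaxPool1d.apply(xT, 2, 1).sum()
s.backward()
print(xT.grad)
```

All arguments are passed positionally to apply here, which avoids relying on how a particular PyTorch version handles default or keyword arguments.
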
Romain Renard answered Jan 31 '23 05:01