I'm working through the PyTorch tutorial on Defining new autograd functions. The autograd function I want to implement is a wrapper around torch.nn.functional.max_pool1d. Here is what I have so far:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, \
                return_indices=False, ceil_mode=False):
        ctx.save_for_backward(input)
        inputC = input.clone()  # copy input
        inputC *= inputC
        output = F.max_pool1d(inputC, kernel_size, stride=stride, \
                              padding=padding, dilation=dilation, \
                              return_indices=return_indices, \
                              ceil_mode=ceil_mode)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = get_max_pool1d_grad_somehow(grad_output)
        return 2.0*input*grad_input
My question is: how do I get the gradient of the wrapped function? I know that there are probably other ways to do this given how simple the example I present is, but what I want to do fits this framework and requires me to implement an autograd function.

Edit: After examining this blog post I decided to try the following for backward:
def backward(ctx, grad_output):
    input, output = ctx.saved_tensors
    grad_input = output.backward(grad_output)
    return 2.0*input*grad_input
with output added to the saved variables. I then run the following code:
x = np.random.randn(1,1,5)
xT = torch.from_numpy(x)
xT.requires_grad=True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT,2))
s.backward()
and I get Bus error: 10.

Say, xT is tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64), then I would expect to find that xT.grad is tensor([[[ 3.39067124, -0. , 9.14775812, -0. , -2.02066994]]], dtype=torch.float64) after calling s.backward() (that is 2*x*grad_of_max_pool, with grad_of_max_pool containing tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64)).
I've figured out why I get a Bus error: 10. It appears that the above code leads to a recursive call of my backward at grad_input = output.backward(grad_output). So I need to find some other way to get the gradient of max_pool1d. I know how to implement this in pure Python, but the result would be much slower than if I could wrap the library code; a rough sketch of what I mean is below.
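For the record, the kind of pure-Python fallback I mean looks roughly like this (only a sketch: it ignores padding, dilation and ceil_mode, and the explicit window loop is exactly what makes it slow):

def max_pool1d_grad_py(input, grad_output, kernel_size, stride=None):
    # naive reference: route each output gradient to the argmax of its window
    stride = stride if stride is not None else kernel_size
    grad_input = torch.zeros_like(input)
    for b in range(input.shape[0]):
        for c in range(input.shape[1]):
            for o in range(grad_output.shape[2]):
                start = o * stride
                window = input[b, c, start:start + kernel_size]
                idx = int(torch.argmax(window))
                grad_input[b, c, start + idx] += grad_output[b, c, o]
    return grad_input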
You have picked a rather unlucky example. torch.nn.functional.max_pool1d is not an instance of torch.autograd.Function, because it's a PyTorch built-in, defined in C++ code with an autogenerated Python binding. I am not sure if it's possible to get the backward property via its interface.
Firstly, in case you haven't noticed, you don't need to write any custom code for backpropagation of this formula, because both the power operation and max_pool1d already have it defined, so their composition is also covered by autograd (a short check to that effect follows the gradcheck snippet further down). Assuming your goal is an exercise, I would suggest you do it more manually (without falling back to the backward of max_pool1d). An example is below:
import torch
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, **kwargs):
        # we're gonna need indices for backward. Currently SquareAnd...
        # never actually returns indices, I left it out for simplicity
        kwargs['return_indices'] = True
        input_sqr = input ** 2
        output, indices = F.max_pool1d(input_sqr, kernel_size, **kwargs)
        ctx.save_for_backward(input, indices)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, indices = ctx.saved_tensors
        # first we need to reconstruct the gradient of `max_pool1d`
        # by putting all the output gradient elements (corresponding to
        # input elements which made it through the max_pool1d) in their
        # respective places, the rest has gradient of 0. We do it by
        # scattering it against a tensor of 0s
        grad_output_unpooled = torch.zeros_like(input)
        grad_output_unpooled.scatter_(2, indices, grad_output)
        # then incorporate the gradient of the "square" part of your
        # operator
        grad_input = 2. * input * grad_output_unpooled
        # the docs for backward
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
        # say that "it should return as many tensors, as there were inputs
        # to forward()". It fails to mention that if an argument was not a
        # tensor, it should return None (I remember reading this somewhere,
        # but can't find it anymore). Anyway, we need to
        # return a (grad_input, None) tuple to avoid a complaint that two
        # outputs were expected
        return grad_input, None
We can then use the numerical gradient checker to verify that the operation works as expected.
f = SquareAndMaxPool1d.apply
xT = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)
tag.gradcheck(lambda t: f(t, 2), xT)
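As a quick check of the earlier claim that autograd already covers the composition, the following sketch (variable names are just for illustration) compares the custom op against the plain built-in composition:

x = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)

# built-in composition: autograd differentiates both ** 2 and max_pool1d
F.max_pool1d(x ** 2, 2).sum().backward()
grad_builtin = x.grad.clone()

# custom Function defined above
x.grad = None
SquareAndMaxPool1d.apply(x, 2).sum().backward()

print(torch.allclose(grad_builtin, x.grad))  # expected to print True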
I'm sorry if this doesn't address your question of how to get the backward of max_pool1d, but hopefully you find my answer useful enough.
The problem you had with the recursive calls actually comes from output, together with the fact that forward and backward of a class inherited from torch.autograd.Function seem to run with a no_grad behavior by default. If you check output.grad_fn in forward, it will probably be None, and in backward, it will probably link to the function object <SquareAndMaxPool1d...>, thus causing the recursive calls. If you are still interested in how to do exactly what you asked, here is an example with F.linear:
import torch
import torch.nn as nn  # needed for the nn.Linear subclass below
import torch.nn.functional as F

class custom_Linear(nn.Linear):
    def forward(self, _input):
        return Custom_Linear_AGfn_getAround.apply(_input, self.weight, self.bias)

class Custom_Linear_AGfn_getAround(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, _weight, _bias):
        print('Custom forward')
        with torch.enable_grad():
            detached_input = _input.detach()
            detached_input.requires_grad_(True)
            detached_weight = _weight.detach()
            detached_weight.requires_grad_(True)
            detached_bias = _bias.detach()
            detached_bias.requires_grad_(True)
            _tmp = F.linear(detached_input, detached_weight, detached_bias)
        ctx.saved_input = detached_input
        ctx.saved_param = detached_weight, detached_bias
        ctx.save_for_backward(_tmp)
        _output = _tmp.detach()
        return _output

    @staticmethod
    def backward(ctx, grad_out):
        print('Custom backward')
        _tmp, = ctx.saved_tensors
        _weight, _bias = ctx.saved_param
        detached_input = ctx.saved_input
        with torch.enable_grad():
            _tmp.backward(grad_out)
        return detached_input.grad, _weight.grad, _bias.grad
Basically, it is just about constructing a small isolated graph for the part of interest without messing with the main graph, and using grad_fn and requires_grad to keep track of the graphs when deciding what to detach and what is needed for the isolated graph.
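To check that the wrapper behaves like a regular layer, a rough usage sketch (layer sizes and seed are arbitrary here) could compare its gradients against a plain nn.Linear carrying the same parameters:

torch.manual_seed(0)
ref = nn.Linear(4, 3)
cus = custom_Linear(4, 3)
cus.load_state_dict(ref.state_dict())  # same weight and bias in both layers

x = torch.randn(2, 4)
ref(x).sum().backward()
cus(x).sum().backward()  # prints 'Custom forward' then 'Custom backward'

print(torch.allclose(ref.weight.grad, cus.weight.grad))  # expected: True
print(torch.allclose(ref.bias.grad, cus.bias.grad))      # expected: True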
About the tricky parts:

- Either you pass _weight and _bias through save_for_backward, in which case _weight.grad and _bias.grad will be None inside backward BUT will have their correct values once outside, or you pass them through an attribute, say ctx.saved_param, in which case you have to manually put None for the last two returned values of backward (return detached_input.grad, None, None), otherwise you will obtain twice the correct value when you check the weight and bias gradients outside of backward afterwards.
- backward and forward of a class inherited from torch.autograd.Function seem to have a no_grad behavior by default. Thus, removing with torch.enable_grad(): in the above code will result in _tmp.grad_fn being None (I could not understand why _tmp had grad_fn set to None and requires_grad set to False in forward by default, despite having required the gradient for detached_input, until I bumped into https://github.com/pytorch/pytorch/issues/7698). A tiny probe illustrating this default behavior is shown after this list.
- Be careful with the grad_fn of _output if you do not detach it: when I do not use with torch.enable_grad() and do not detach the output (so that _tmp.grad_fn is None in forward), it does acquire <Custom_Linear_AGfn_getAround...> as its grad_fn in backward (and results in the infinite recursive calls).
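The no_grad-like default from the second point can be seen with a tiny probe (the class and variable names here are made up for illustration):

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        print(y.grad_fn)  # None: forward runs with gradient recording disabled
        with torch.enable_grad():
            z = x.detach().requires_grad_(True) * 2
        print(z.grad_fn)  # <MulBackward0 ...>: recording re-enabled explicitly
        return y

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

t = torch.randn(3, requires_grad=True)
Probe.apply(t)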