 

PyTorch autograd gives different gradients when using .clamp instead of torch.relu

I'm still working on my understanding of the PyTorch autograd system. One thing I'm struggling with is understanding why .clamp(min=0) and nn.functional.relu() seem to have different backward passes.

It's especially confusing as .clamp is used equivalently to relu in PyTorch tutorials, such as https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn.

I found this while analysing the gradients of a simple fully connected net with one hidden layer and a ReLU activation (linear in the output layer).

To my understanding, the output of the following code should be all zeros. I hope someone can show me what I am missing.

import torch
dtype = torch.float

x = torch.tensor([[3,2,1],
                  [1,0,2],
                  [4,1,2],
                  [0,0,1]], dtype=dtype)

y = torch.ones(4,4)

w1_a = torch.tensor([[1,2],
                     [0,1],
                     [4,0]], dtype=dtype, requires_grad=True)
w1_b = w1_a.clone().detach()
w1_b.requires_grad = True



w2_a = torch.tensor([[-1, 1],
                     [-2, 3]], dtype=dtype, requires_grad=True)
w2_b = w2_a.clone().detach()
w2_b.requires_grad = True


y_hat_a = torch.nn.functional.relu(x.mm(w1_a)).mm(w2_a)
y_a = torch.ones_like(y_hat_a)
y_hat_b = x.mm(w1_b).clamp(min=0).mm(w2_b)
y_b = torch.ones_like(y_hat_b)

loss_a = (y_hat_a - y_a).pow(2).sum()
loss_b = (y_hat_b - y_b).pow(2).sum()

loss_a.backward()
loss_b.backward()

print(w1_a.grad - w1_b.grad)
print(w2_a.grad - w2_b.grad)

# OUT:
# tensor([[  0.,   0.],
#         [  0.,   0.],
#         [  0., -38.]])
# tensor([[0., 0.],
#         [0., 0.]])
# 
Asked Mar 10 '20 by DaFlooo


People also ask

Is PyTorch clamp differentiable?

Like ReLU, clamp(min=0) is continuous but not differentiable at 0; autograd uses a one-sided (sub)gradient at that point.

How does Autograd work in PyTorch?

Autograd is a reverse-mode automatic differentiation system. Conceptually, autograd records a graph of all the operations that created the data as you execute them, giving you a directed acyclic graph whose leaves are the input tensors and whose roots are the output tensors.
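This graph can be inspected through each tensor's grad_fn attribute. A minimal sketch (the expression is made up for illustration):

```python
import torch

# Build a tiny graph: c = (a * 2) + 3
a = torch.tensor(1.0, requires_grad=True)
b = a * 2
c = b + 3

# Each non-leaf tensor points at the backward node that created it;
# grad_fn.next_functions links back toward the leaves.
print(type(c.grad_fn).__name__)  # AddBackward0
print(type(b.grad_fn).__name__)  # MulBackward0
```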

How does PyTorch calculate gradient?

PyTorch computes the gradient of a function with respect to the inputs by using automatic differentiation. Automatic differentiation is a technique that, given a computational graph, calculates the gradients of the inputs. It can be performed in two ways: forward mode and reverse mode.
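A minimal reverse-mode example, checking autograd against the hand-computed derivative (the function is an arbitrary illustration):

```python
import torch

# y = x^2 + 3x, so dy/dx = 2x + 3; at x = 2 that is 7.
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3 * x
y.backward()
print(x.grad)  # tensor(7.)
```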

How do you stop gradients in PyTorch?

If you know during the forward pass which part you want to block the gradients from, you can call .detach() on the output of that block to exclude it from the backward pass.
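For example, a sketch using a made-up two-term expression where one term is detached:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
blocked = (x * 3).detach()  # treated as a constant by autograd
y = x + blocked             # only the direct x term carries gradient
y.backward()
print(x.grad)  # tensor(1.)
```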


1 Answer

The reason is that clamp and relu produce different gradients at 0. You can check this with a scalar tensor x = 0 by comparing the two versions: (x.clamp(min=0) - 1.0).pow(2).backward() versus (relu(x) - 1.0).pow(2).backward(). The resulting x.grad is 0 for the relu version but -2 for the clamp version. In other words, relu chooses x == 0 --> grad = 0, while clamp chooses x == 0 --> grad = 1.
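The scalar check described above, written out:

```python
import torch

# Gradient of (f(x) - 1)^2 at x == 0: the chain rule gives
# 2 * (f(0) - 1) * f'(0) = -2 * f'(0), so the result exposes f'(0).
x1 = torch.tensor(0.0, requires_grad=True)
(x1.clamp(min=0) - 1.0).pow(2).backward()

x2 = torch.tensor(0.0, requires_grad=True)
(torch.relu(x2) - 1.0).pow(2).backward()

print(x1.grad)  # tensor(-2.) -> clamp uses gradient 1 at 0
print(x2.grad)  # tensor(0.)  -> relu uses gradient 0 at 0
```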

Answered Oct 10 '22 by a_guest