See the code snippet:
import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)
The output is tensor([0.]), but
import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)
The output is None.
I'm confused: why is the output of the torch.where version tensor([0.])?
Another example:
import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)
The output is tensor([2., 0.]). (a[0, 0] * a[0, 1]) is not in any way related to b[1], but the gradient of b[1] is 0, not None.
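A reduced sketch of what I think is happening (my own reduction; the variable names are just for illustration):
import torch
# .grad is accumulated per leaf tensor and has the same shape as that leaf,
# so elements that do not influence the output receive 0 rather than None.
b = torch.tensor([-1., -1.], requires_grad=True)
loss = b[0] * 2.   # only b[0] is used
loss.backward()
print(b.grad)      # tensor([2., 0.])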
In autograd, the forward function computes output Tensors from input Tensors, and the backward function receives the gradient of the output Tensors with respect to some scalar value and computes the gradient of the input Tensors with respect to that same scalar value. All differentiable torch.* functions support this; functions such as indexing (and a few others) are not differentiable with respect to their indices or certain of their inputs, and trying to differentiate through those raises an appropriate error.
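A minimal sketch of that forward/backward contract (my own illustration; the Square function below is hypothetical, not from the post):
import torch

class Square(torch.autograd.Function):
    # forward: compute the output tensor from the input tensor
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    # backward: receive dL/d(output) and return dL/d(input)
    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor([3.], requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor([6.])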
Tracking-based AD, like PyTorch, works by tracking: you can't track through things that are not function calls intercepted by the library. By using an if statement like this, there is no connection between x and y, whereas with where, x and y are linked in the expression tree.
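A small sketch of that linkage (my own illustration; the constant branch is named z here only so its gradient can be inspected):
import torch
x = torch.tensor([-1.], requires_grad=True)
z = torch.tensor([2.], requires_grad=True)
y = torch.where(x > 0., x, z)
y.backward()
print(x.grad)  # tensor([0.]) - x is in the graph, but its branch was not selected
print(z.grad)  # tensor([1.]) - the selected branch receives the gradient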
Now, for the differences:
- 0 is the correct derivative of the function x ↦ (x > 0 ? x : 2) at the point -1 (since the negative branch is constant).
- x is not in any way related to y (in the else branch). Therefore, the derivative of y with respect to x is undefined, which is represented as None.
(You can differentiate through such control flow even in Python, but that requires more sophisticated technology like source transformation. I don't think it is possible with PyTorch.)
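To confirm the first point, the same where expression evaluated on the positive side gives a derivative of 1 (again just a sketch):
import torch
x = torch.tensor([3.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.]))
y.backward()
print(x.grad)  # tensor([1.]) - the selected branch is the identity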