
what is the difference between if-else statement and torch.where in pytorch?

See the code snippet:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

The output is tensor([0.]), but

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

The output is None.

I'm confused: why is the output of torch.where tensor([0.]) rather than None?

Update:

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

The output is tensor([2., 0.]). The product (a[0, 0] * a[0, 1]) does not depend on b[1] in any way, yet the gradient of b[1] is 0, not None.

asked Apr 13 '20 by gaussclb


1 Answer

PyTorch's automatic differentiation is tracing-based: it builds the computation graph by recording the tensor operations it intercepts. It cannot trace through plain Python control flow. With an if-else statement, no recorded operation connects x to y, whereas with torch.where, x and y are linked in the expression tree.
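You can see this difference directly by inspecting grad_fn on the two results (a minimal sketch using the same tensors as the question; the exact grad_fn class name varies across PyTorch versions):

```python
import torch

x = torch.tensor([-1.], requires_grad=True)

# torch.where is an intercepted tensor op, so x is part of y's graph:
y_where = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
print(y_where.grad_fn)  # a backward node for where -> connected to x

# The if-else comparison happens in plain Python, outside the graph;
# y is simply the standalone tensor from the else branch:
if x > 0.:
    y_if = x
else:
    y_if = torch.tensor([2.], requires_grad=True)
print(y_if.grad_fn)  # None -> y_if is a leaf tensor, unrelated to x
```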

Now, for the differences:

  • In the first snippet, 0 is the correct derivative of the function x ↦ x > 0 ? x : 2 at the point -1 (since the negative side is constant).
  • In the second snippet, as I said, x is not in any way related to y (in the else branch). Therefore, the derivative of y given x is undefined, which is represented as None.
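To illustrate the first point, here is a small sketch evaluating the torch.where expression on both sides of the branch (variable names x_pos and x_neg are mine):

```python
import torch

# Positive side: where selects x itself, so dy/dx = 1
x_pos = torch.tensor([3.], requires_grad=True)
y = torch.where(x_pos > 0., x_pos, torch.tensor([2.], requires_grad=True))
y.backward()
print(x_pos.grad)  # tensor([1.])

# Negative side: where selects the constant 2, so dy/dx = 0
x_neg = torch.tensor([-1.], requires_grad=True)
y = torch.where(x_neg > 0., x_neg, torch.tensor([2.], requires_grad=True))
y.backward()
print(x_neg.grad)  # tensor([0.])
```

So the 0 in the question is not a missing gradient; it is the genuine derivative of the constant branch.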

(Differentiating through native control flow is possible, but it requires more sophisticated techniques such as source transformation. I don't think PyTorch supports this.)

answered Sep 16 '22 by phipsgabler