This is a very simple example:
import torch
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()
print(x.grad)
This prints
tensor([2., 2., 0., 0., 0.])
since, of course, ds/dx = y * z is zero for the entries where z is zero.
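A quick way to check the gradient formula is to compare x.grad with y * z directly. The same snippet also prints z.grad. Because z requires a gradient, autograd needs the products x_i * y_i for every entry, including the ones where z_i is zero (for example 2*5 = 10 in the last position). This is a small sketch reusing the tensors from the question, not a statement about PyTorch internals beyond what the printed values show.

```python
import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()

# For s = sum_i x_i * y_i * z_i, the partial derivative w.r.t. x_i is y_i * z_i
print(torch.equal(x.grad, (y * z).detach()))  # True
# z.grad holds x_i * y_i for every entry, even where z_i == 0
print(z.grad)
```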
My question is: is PyTorch smart enough to stop the computation when it reaches a zero? Or does it in fact compute "2*5", only to later compute "10 * 0 = 0"?
In this small example it hardly matters, but in the (bigger) problem I am working on, it will.
Thank you for any input.
No, PyTorch does not prune any subsequent calculations when a zero is reached. Even worse, because of how floating-point arithmetic works, every subsequent multiplication by zero takes roughly the same time as any regular multiplication.
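A rough way to see that multiplying by zeros is not cheaper is to time an element-wise product against a tensor of zeros and against a tensor of random values. This is a CPU-only sketch; `best_time` and the tensor size are illustrative choices, and absolute timings depend entirely on your hardware (on a GPU you would also need `torch.cuda.synchronize()` around the timer).

```python
import time
import torch

N = 10_000_000
a = torch.rand(N)

def best_time(other, repeats=20):
    # Hypothetical helper: best-of-N wall-clock time for the product a * other
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        result = a * other
        best = min(best, time.perf_counter() - start)
    return best, result

t_zeros, r_zeros = best_time(torch.zeros(N))
t_rand, r_rand = best_time(torch.rand(N))
print(f"a * zeros:  {t_zeros:.4f} s")
print(f"a * random: {t_rand:.4f} s")
```

Taking the best of several runs reduces noise from thread-pool start-up and other one-off costs; the two numbers are expected to be of the same order.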
In some cases there are workarounds, though. For example, if you want to use a masked loss, you can set the masked outputs to zero, or detach them so that no gradient flows back through them.
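The two masked-loss workarounds can be sketched on a toy squared-error loss. The tensors `pred`, `target` and `mask` are hypothetical stand-ins for a model's outputs; in a real model the benefit of detaching is that backward does not continue into the layers that produced the masked outputs, which this leaf-only toy cannot show.

```python
import torch

# Toy data (hypothetical): predictions, targets, and a mask where True = keep
pred = torch.tensor([0.5, 1.5, 2.0, 3.0], requires_grad=True)
target = torch.ones(4)
mask = torch.tensor([True, True, False, False])

# Option 1: set masked outputs (and their targets) to zero so they add nothing to the loss
zeroed_pred = torch.where(mask, pred, torch.zeros_like(pred))
zeroed_target = torch.where(mask, target, torch.zeros_like(target))
loss_zeroed = ((zeroed_pred - zeroed_target) ** 2).sum()
loss_zeroed.backward()
grad_zeroed = pred.grad.clone()

# Option 2: detach masked outputs so backward does not flow through them.
# The loss value still contains the masked terms; only their gradient is cut.
pred.grad = None
detached_pred = torch.where(mask, pred, pred.detach())
loss_detached = ((detached_pred - target) ** 2).sum()
loss_detached.backward()
grad_detached = pred.grad.clone()

print(grad_zeroed)    # zero gradient at the masked positions
print(grad_detached)  # same gradient: masked entries receive none
```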
This example makes the difference clear:
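import time
import torch

def time_backward(do_detach):
    # Two large leaf tensors that require gradients
    x = torch.rand(100000000, requires_grad=True)
    y = torch.rand(100000000, requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        # Cut s2 out of the autograd graph so backward never visits its branch
        s2 = s2.detach()
    # 0 * s2 adds nothing numerically, but without detach its branch is still back-propagated
    s = s1 + 0 * s2
    t = time.perf_counter()
    s.backward()
    print(time.perf_counter() - t)

time_backward(do_detach=False)
time_backward(do_detach=True)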
Output (seconds):
0.502875089645
0.198422908783
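One way to see where the extra time goes is to count the nodes that backward has to walk from `s.grad_fn`. The helpers `count_backward_nodes` and `build` are hypothetical, the tensors are tiny so the graph can be inspected cheaply, and the node count is only a proxy for work, not a cost model. The sketch relies on the real `grad_fn.next_functions` attribute, which lists each node's parent nodes (with `None` for inputs that do not require gradients).

```python
import torch

def count_backward_nodes(fn):
    # Walk the autograd graph from a grad_fn; shared nodes are counted once per path
    if fn is None:
        return 0
    return 1 + sum(count_backward_nodes(child) for child, _ in fn.next_functions)

def build(do_detach, n=10):
    # Same construction as time_backward, without the timing
    x = torch.rand(n, requires_grad=True)
    y = torch.rand(n, requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        s2 = s2.detach()
    return s1 + 0 * s2

nodes_full = count_backward_nodes(build(do_detach=False).grad_fn)
nodes_detached = count_backward_nodes(build(do_detach=True).grad_fn)
print("without detach:", nodes_full)
print("with detach:   ", nodes_detached)
```

Without `detach`, the whole `s2` branch (including its `x * y` product) is part of the graph even though it is multiplied by zero, which matches the longer backward time above.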