
Does pytorch do eager pruning of its computational graph?

This is a very simple example:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

This will print,

tensor([2., 2., 0., 0., 0.]),

since, of course, ds/dx is zero for the entries where z is zero.
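
For reference, this is exactly the chain rule at work: the complete snippet below checks that the gradient equals `y * z` elementwise.

```python
import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

# chain rule: ds/dx = y * z, which is zero wherever z is zero
assert torch.equal(x.grad, y * z)
```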

My question is: is PyTorch smart enough to stop the computation when it reaches a zero? Or does it in fact compute `2 * 5`, only to later do `10 * 0 = 0`?

In this simple example it doesn't make a big difference, but in the (bigger) problem I am looking at, this will make a difference.

Thank you for any input.

asked Mar 08 '26 by Julius

1 Answer

No, PyTorch does not prune subsequent calculations when a zero is reached. Even worse, due to how floating-point arithmetic works, every subsequent multiplication by zero takes roughly the same time as a regular multiplication.

For some cases there are ways around it, though: for example, if you want a masked loss, you can set the masked outputs to zero, or detach them from the gradient computation.

This example makes the difference clear:

import time
import torch

def time_backward(do_detach):
    x = torch.rand(100000000, requires_grad=True)
    y = torch.rand(100000000, requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        s2 = s2.detach()
    # s2 contributes nothing to the value of s either way,
    # but without detach() backward still traverses its graph
    s = s1 + 0 * s2
    t = time.time()
    s.backward()
    print(time.time() - t)

time_backward(do_detach=False)
time_backward(do_detach=True)

outputs:

0.502875089645
0.198422908783
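
Along the same lines, a minimal sketch of the masked-loss workaround (the names `pred`, `target`, and `mask` here are illustrative): indexing with the mask keeps the masked-out entries out of the loss entirely, so backward never traverses their part of the graph, and their gradients come back as exact zeros.

```python
import torch

# hypothetical setup: predictions, targets, and a boolean mask
pred = torch.rand(5, requires_grad=True)
target = torch.rand(5)
mask = torch.tensor([True, True, False, False, False])

# only masked-in entries enter the loss; backward scatters
# zeros into the masked-out positions instead of computing them
loss = torch.sum(((pred - target) ** 2)[mask])
loss.backward()

print(pred.grad)  # last three entries are exactly zero
```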
answered Mar 11 '26 by Chris Holland


