
PyTorch gradient differs from manually calculated gradient

I'm trying to compute the gradient of 1/x without using PyTorch's autograd. I'm using the formula grad(1/x, x) = -1/x**2. When I compare the result from this formula to the gradient given by PyTorch's autograd, they're different.

Here is my code:

import numpy as np
import torch

dtype = torch.float32  # dtype is defined elsewhere in the original; float32 assumed here
a = torch.tensor(np.random.randn(), dtype=dtype, requires_grad=True)
loss = 1 / a
loss.backward()
print(a.grad - (-1 / (a ** 2)))

The output is:

tensor(5.9605e-08, grad_fn=<ThAddBackward>)

Can anyone explain to me what the problem is?

Asked by HOANG GIANG on Nov 13 '18

1 Answer

So I guess you expected zero as the result. If you take a closer look, you'll see that the result is quite close to zero. When dividing numbers on a binary system (a computer), you often get round-off errors.

Let's take a look at your example with an additional print statement added:

import numpy as np
import torch

a = torch.tensor(np.random.randn(), requires_grad=True)
loss = 1 / a
loss.backward()
print(a.grad, (-1 / (a ** 2)))
print(a.grad - (-1 / (a ** 2)))

Because you use a random input, the output is of course random too (so you won't get these very same numbers, but if you repeat the experiment you will get similar results).

Sometimes you will get zero as the result, but this was not the case in your initial example:

tensor(-0.9074) tensor(-0.9074, grad_fn=<MulBackward>)
tensor(5.9605e-08, grad_fn=<ThSubBackward>)

You can see that even though both are displayed as the same number, they differ in one of the last decimal places. That is why you get this very small difference when subtracting the two.
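
Because of this, floating-point results are usually compared within a tolerance rather than checked for an exact zero difference. A minimal sketch using torch.allclose (the input value here is arbitrary):

import torch

a = torch.tensor(0.3, requires_grad=True)
loss = 1 / a
loss.backward()
# Compare the values within a tolerance instead of expecting an exact zero difference.
print(torch.allclose(a.grad, (-1 / (a ** 2)).detach()))  # True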

This is a general problem of computers: some fractions just have a large or infinite number of decimal places, but your computer's memory does not, so they are cut off at some point.

So what you are experiencing here is a lack of precision, and the precision depends on the numerical data type you are using (i.e. torch.float32 or torch.float64).

You can also take a look here for more information:
https://en.wikipedia.org/wiki/Double-precision_floating-point_format
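
To see how the data type affects this, you can run the same computation in single and double precision. This is just a sketch, and the exact magnitudes depend on the random input:

import numpy as np
import torch

x = np.random.randn()
for dtype in (torch.float32, torch.float64):
    a = torch.tensor(x, dtype=dtype, requires_grad=True)
    loss = 1 / a
    loss.backward()
    # The difference typically shrinks from around 1e-08 (float32) to around 1e-16 (float64).
    print(dtype, (a.grad - (-1 / (a ** 2))).item())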


But this is not specific to PyTorch; here is a plain Python example:

print(29/100*100)

Results in:

28.999999999999996
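
If you want to see the value that is actually stored, the decimal module can print the exact binary representation behind a float (just an illustration):

from decimal import Decimal

# 29/100 cannot be represented exactly in binary floating point;
# Decimal shows the exact value that is stored, which is slightly less than 0.29.
print(Decimal(29 / 100))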

Edit:

As @HOANG GIANG pointed out, changing the expression to -(1/a)*(1/a) works well and the result is zero. This is probably because the calculation done to compute the gradient is very similar (or even identical) to -(1/a)*(1/a) in this case. It therefore shares the same round-off errors, so the difference is zero.
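
A quick sketch of that check in PyTorch (the input is random; for many values the printed difference is exactly zero):

import numpy as np
import torch

a = torch.tensor(np.random.randn(), requires_grad=True)
loss = 1 / a
loss.backward()
# Compare against -(1/a)*(1/a), which is presumably close to how the backward pass computes it.
print(a.grad - (-(1 / a) * (1 / a)))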

So here is another example that fits better than the one above. Even though -(1/x)*(1/x) is mathematically equivalent to -1/x^2, the computed results are not always the same, depending on the value of x:

import numpy as np

print('e1 == e2', 'x value', '\t'*2, 'round-off error', sep='\t')
print('=' * 70)
for i in range(10):
    x = np.random.randn()
    e1 = -(1/x) * (1/x)  # product of two reciprocals
    e2 = -1 / (x**2)     # reciprocal of the square
    print(e1 == e2, x, e1 - e2, sep='\t\t')

Output:

e1 == e2    x value                 round-off error
======================================================================
True        0.2934154339948173      0.0
True        -1.2881863891014191     0.0
True        1.0463038021843876      0.0
True        -0.3388766143622498     0.0
True        -0.6915415747192347     0.0
False       1.3299049850551317      1.1102230246251565e-16
True        -1.2392046539563553     0.0
False       -0.42534236747121645    8.881784197001252e-16
True        1.407198823994324       0.0
False       -0.21798652132356966    3.552713678800501e-15

Even though the round-off error seems to occur a bit less often here (I tried different random values, and rarely did more than two out of ten have a round-off error), there are already small differences when just calculating 1/x:

import numpy as np

print('x == 1/(1/x)', 'x value', '\t'*2, 'round-off error', sep='\t')
print('=' * 70)
for i in range(10):
    x = np.random.randn()
    # calculate 1/x
    result = 1 / x
    # apply the inverse function
    reconstructed_x = 1 / result
    # mathematically this should be the same as x
    print(x == reconstructed_x, x, x - reconstructed_x, sep='\t\t')

Output:

x == 1/(1/x)    x value             round-off error
======================================================================
False       0.9382823115235075      1.1102230246251565e-16
True        -0.5081217386356917     0.0
True        -0.04229436058156134    0.0
True        1.1121100294357302      0.0
False       0.4974618312372863      -5.551115123125783e-17
True        -0.20409933212316553    0.0
True        -0.6501652554924282     0.0
True        -3.048057937738731      0.0
True        1.6236075700470816      0.0
True        0.4936926651641918      0.0
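
For reference, the differences above are on the scale of the machine epsilon of the respective data type (a quick check):

import numpy as np

# Machine epsilon: the spacing between 1.0 and the next larger representable float.
print(np.finfo(np.float32).eps)  # ~1.19e-07, same scale as the ~6e-08 difference in the question
print(np.finfo(np.float64).eps)  # ~2.22e-16, same scale as the ~1e-16 differences above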
Answered by MBT on Sep 28 '22