
What is the purpose of with torch.no_grad():

Consider the following code for Linear Regression implemented using PyTorch:

Here X is the training input, Y is the training output, and w is the parameter to be optimised:
import torch

X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

def forward(x):
    return w * x

def loss(y, y_pred):
    return ((y_pred - y)**2).mean()

print(f'Prediction before training: f(5) = {forward(5).item():.3f}')

learning_rate = 0.01
n_iters = 100

for epoch in range(n_iters):
    # predict = forward pass
    y_pred = forward(X)

    # loss
    l = loss(Y, y_pred)

    # calculate gradients = backward pass
    l.backward()

    # update weights
    #w.data = w.data - learning_rate * w.grad
    with torch.no_grad():
        w -= learning_rate * w.grad
    
    # zero the gradients after updating
    w.grad.zero_()

    if epoch % 10 == 0:
        print(f'epoch {epoch+1}: w = {w.item():.3f}, loss = {l.item():.8f}')

What does the 'with' block do? The requires_grad argument for w is already set to True. Why is it then being put under a with torch.no_grad() block?

asked Nov 21 '25 23:11 by Gaurav_Misra



1 Answer

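Setting requires_grad=True tells PyTorch to track the operations performed on w so that l.backward() can compute w.grad. The weight update itself, however, is not part of the model's computation and should not be tracked. Wrapping it in with torch.no_grad() temporarily disables gradient tracking, so the update is applied directly to the values of w without being recorded in the computation graph. Without the block, PyTorch raises a RuntimeError for this in-place update, because w is a leaf tensor that requires grad. Inside the block w still has requires_grad=True, so gradients are computed for it as usual on the next iteration.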
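As a rough illustration of the point above, the following sketch runs one forward/backward pass on the same data as in the question and then tries the update three ways. The names update_rejected and tracked are hypothetical helpers introduced only for this example, and the learning rate of 0.01 is taken from the question.

```python
import torch

# Same toy data as in the question.
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([2.0, 4.0, 6.0, 8.0])
w = torch.tensor(0.0, requires_grad=True)

# One forward and backward pass to populate w.grad.
loss = ((w * x - y) ** 2).mean()
loss.backward()

# 1) In-place update while tracking is enabled: PyTorch refuses,
#    because w is a leaf tensor that requires grad.
update_rejected = False
try:
    w -= 0.01 * w.grad
except RuntimeError:
    update_rejected = True

# 2) Out-of-place update while tracking is enabled: this works, but the
#    result is a new non-leaf tensor that is recorded in the graph.
tracked = w - 0.01 * w.grad

# 3) Update inside no_grad: w is modified directly, nothing is recorded,
#    and w remains a leaf tensor with requires_grad=True.
with torch.no_grad():
    w -= 0.01 * w.grad

print(update_rejected, tracked.is_leaf, w.is_leaf, w.requires_grad)
```

Case 2 is why simply writing w = w - learning_rate * w.grad is not a good substitute: the name w would then refer to a non-leaf tensor, and its .grad would not be populated by the next backward() call.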

answered Nov 24 '25 13:11 by Andrei
