 

PyTorch set_grad_enabled(False) vs with no_grad():

Tags:

pytorch

Assuming autograd is on (as it is by default), is there any difference (besides indent) between doing:

with torch.no_grad():
    <code>

and

torch.set_grad_enabled(False)
<code>
torch.set_grad_enabled(True)
Asked Nov 23 '18 by Tom Hale

People also ask

Is torch.no_grad() the same as model.eval()?

No. model.eval() notifies all your layers that you are in eval mode, so that batchnorm or dropout layers behave in eval mode instead of training mode. torch.no_grad() affects the autograd engine and deactivates it.
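For illustration, here is a minimal sketch (with a hypothetical toy model) of how the two are typically combined for inference: model.eval() switches layer behaviour, while torch.no_grad() disables autograd:

import torch
import torch.nn as nn

# Hypothetical toy model, only to demonstrate the pattern.
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.Linear(10, 2))

model.eval()                 # dropout/batchnorm layers switch to eval behaviour
with torch.no_grad():        # autograd engine is deactivated in this block
    out = model(torch.randn(4, 10))

print(out.requires_grad)     # False: no computational graph was built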

What does torch.set_grad_enabled() do?

set_grad_enabled is one of several mechanisms that can enable or disable gradients locally; see "Locally disabling gradient computation" in the PyTorch docs for more information on how they compare.
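As a quick sketch of what that means in practice, set_grad_enabled can be called as a plain function (a global switch) or used as a context manager (scoped to a block):

import torch

torch.set_grad_enabled(False)        # plain call: disables gradients globally
print(torch.is_grad_enabled())       # False
torch.set_grad_enabled(True)         # restore the default

with torch.set_grad_enabled(False):  # context manager: only inside the block
    print(torch.is_grad_enabled())   # False
print(torch.is_grad_enabled())       # True again after the block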

What is "with torch.no_grad()"?

"with torch.no_grad()" is a context manager: every tensor produced by computations inside the block will have requires_grad set to False. In other words, results computed in the block are detached from the current computational graph.
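A minimal sketch of that detachment behaviour: a result computed inside the block has requires_grad=False and no grad_fn, even though its input requires gradients:

import torch

x = torch.tensor([1.0], requires_grad=True)
with torch.no_grad():
    y = x * 2

print(y.requires_grad)  # False
print(y.grad_fn)        # None: y is detached from the computational graph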

What is requires_grad in PyTorch?

requires_grad is a flag that controls whether a tensor requires a gradient or not. To create a tensor with gradients, pass the extra parameter requires_grad=True when creating the tensor. Only floating point and complex dtype tensors can require gradients.
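For example (a minimal sketch), the flag can be set at creation time, and trying to set it on an integer tensor raises an error:

import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
print(a.requires_grad)        # True

b = torch.tensor([1.0, 2.0])  # requires_grad defaults to False
print(b.requires_grad)        # False

# Integer tensors cannot require gradients; uncommenting the next line raises
# "RuntimeError: Only Tensors of floating point and complex dtype can require gradients".
# c = torch.tensor([1, 2], requires_grad=True)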


1 Answer

Actually no, there is no difference in the way they are used in the question. When you take a look at the source code of no_grad, you see that it actually uses torch.set_grad_enabled to achieve this behaviour:

class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.


    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad

However, torch.set_grad_enabled has additional functionality over torch.no_grad when used in a with-statement: it takes a boolean argument, which lets you switch gradient computation on or off:

    >>> x = torch.tensor([1.], requires_grad=True)
    >>> is_train = False
    >>> with torch.set_grad_enabled(is_train):
    ...   y = x * 2
    >>> y.requires_grad
    False

https://pytorch.org/docs/stable/_modules/torch/autograd/grad_mode.html
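For illustration, here is a hedged sketch (with a hypothetical model and data) of the common pattern this boolean argument enables: running the same code path for training and evaluation, and building a graph only when is_train is True:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
x, target = torch.randn(4, 10), torch.randn(4, 2)
loss_fn = nn.MSELoss()

for phase in ('train', 'eval'):
    is_train = phase == 'train'
    model.train(is_train)                   # switch layer behaviour
    with torch.set_grad_enabled(is_train):  # switch autograd on/off
        out = model(x)
        loss = loss_fn(out, target)
        if is_train:
            loss.backward()                 # only possible while grad is enabled
    print(phase, loss.item(), out.requires_grad)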


Edit:

@TomHale Regarding your comment: I just ran a short test with PyTorch 1.0, and it turned out that the gradient will indeed be active:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
torch.set_grad_enabled(False)
with torch.enable_grad():
    scalar = w.sum()
    scalar.backward()
    # Gradient tracking will be enabled here.
torch.set_grad_enabled(True)

print('Grad After:', w.grad)

Output:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])

So gradients will be computed in this setting.

The other setting you posted in your answer also yields the same result:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
with torch.no_grad():
    with torch.enable_grad():
        # Gradient tracking IS enabled here.
        scalar = w.sum()
        scalar.backward()

print('Grad After:', w.grad)

Output:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])
Answered Sep 22 '22 by MBT