PyTorch torch.no_grad() versus requires_grad=False

I'm following a PyTorch tutorial which uses the BERT NLP model (feature extractor) from the Huggingface Transformers library. There are two pieces of interrelated code for gradient updates that I don't understand.

(1) torch.no_grad()

The tutorial has a class where the forward() function creates a torch.no_grad() block around a call to the BERT feature extractor, like this:

import torch
import torch.nn as nn
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):

    def __init__(self, bert):
        super().__init__()
        self.bert = bert

    def forward(self, text):
        # No gradients are recorded for the BERT forward pass;
        # its output is then fed to the rest of the model (GRU, etc.).
        with torch.no_grad():
            embedded = self.bert(text)[0]

(2) param.requires_grad = False

There is another portion in the same tutorial where the BERT parameters are frozen.

# Freeze every parameter that belongs to the BERT sub-module.
for name, param in model.named_parameters():
    if name.startswith('bert'):
        param.requires_grad = False
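
For reference, the parameter counts I report in the table further down are the numbers of trainable parameters, i.e. those with requires_grad = True. A helper along these lines (my own sketch, not necessarily the tutorial's exact code) reproduces those counts:

def count_trainable_parameters(model):
    # Count only the parameters that would receive gradient updates.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# With the BERT parameters frozen as above this prints roughly 3M (the
# non-BERT layers); without freezing it prints roughly 112M.
print(f'{count_trainable_parameters(model):,} trainable parameters')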

When would I need (1) and/or (2)?

  • If I want to train with a frozen BERT, would I need to enable both?
  • If I want to train and let BERT be updated (fine-tuned), would I need to disable both?

Additionally, I ran all four combinations and found:

   with torch.no_grad   requires_grad = False   Trainable params   Result
   ------------------   ---------------------   ----------------   ------------------
a. Yes                  Yes                        3M               Ran successfully
b. Yes                  No                       112M               Ran successfully
c. No                   Yes                        3M               Ran successfully
d. No                   No                       112M               CUDA out of memory

Can someone please explain what's going on? Why am I getting CUDA out of memory for (d) but not (b)? Both have 112M learnable parameters.

asked Sep 07 '20 by stackoverflowuser2010



1 Answer

This is an older discussion, which has changed slightly over the years (mainly due to the purpose of with torch.no_grad() as a pattern). An excellent answer that largely covers your question as well can already be found on Stack Overflow.
However, since the original question is vastly different, I'll refrain from marking it as a duplicate, especially because of the second part about memory.

An initial explanation of no_grad is given here:

with torch.no_grad() is a context manager and is used to prevent calculating gradients [...].

requires_grad on the other hand is used

to freeze part of your model and train the rest [...].

Source: again, the SO post.
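
As a side note (my own illustration, not taken from the tutorial): when freezing parts of the model with requires_grad = False, a common companion pattern is to hand the optimizer only the parameters that remain trainable, so the frozen BERT weights receive neither updates nor optimizer state:

import torch.optim as optim

# Only parameters with requires_grad=True are passed to the optimizer;
# the frozen BERT weights get no updates and no optimizer state.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)  # learning rate chosen arbitrarily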

Essentially, with requires_grad you are only disabling gradient tracking for part of a network, whereas no_grad will not store any gradients at all, since you're likely using it for inference and not training.
To analyze the behavior of your combinations of parameters, let us investigate what is happening (a small sketch illustrating this follows the list):

  • a) and b) do not store any gradients at all, which means that you have vastly more memory available to you, no matter the number of parameters, since nothing from the BERT forward pass is retained for a potential backward pass.
  • c) has to store the forward pass for later backpropagation, but only for the part with trainable parameters (3 million), which keeps memory usage manageable.
  • d), however, needs to store the forward pass through all 112 million trainable parameters, which causes you to run out of memory.
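
To make this concrete, here is a small self-contained sketch (using a toy nn.Linear as a stand-in for BERT, so it runs without downloading any weights) that checks whether autograd records a graph, i.e. keeps the forward pass around for backpropagation, in each of your four combinations:

import torch
import torch.nn as nn

def probe(use_no_grad, freeze):
    # Toy stand-in for the BERT feature extractor.
    extractor = nn.Linear(10, 10)
    if freeze:
        for p in extractor.parameters():
            p.requires_grad = False
    x = torch.randn(1, 10)  # inputs never require grad, much like token ids
    if use_no_grad:
        with torch.no_grad():
            out = extractor(x)
    else:
        out = extractor(x)
    # A non-None grad_fn means autograd stored the forward pass
    # (activations kept for a later backward), which is what costs memory.
    print(f'no_grad={use_no_grad}, frozen={freeze} -> graph stored: {out.grad_fn is not None}')

for use_no_grad in (True, False):
    for freeze in (True, False):
        probe(use_no_grad, freeze)

Only the no_grad=False, frozen=False combination reports a stored graph through the stand-in extractor, which corresponds to your case d) and explains the out-of-memory error once the extractor has 112 million parameters.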
answered Sep 29 '22 by dennlinger