When I want to evaluate the performance of my model on the validation set, is it preferred to use with torch.no_grad(): or model.eval()?
model.eval() puts a PyTorch model into evaluation mode. It acts as a switch for particular parts of the model that behave differently at training time and at evaluation time.
with torch.no_grad() is a context manager: every tensor computed inside the block has requires_grad set to False. In other words, the results of those computations are detached from the current computational graph and no gradients are tracked for them.
model.train() tells your model that you are training it. Layers like Dropout and BatchNorm, which behave differently during training and evaluation, then know what is going on and can behave accordingly.
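A minimal sketch of this switch: calling train() or eval() on a module simply toggles its training flag, which mode-dependent layers such as Dropout consult in their forward pass.

```python
import torch.nn as nn

# A small model containing a layer whose behaviour depends on the mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.train()            # training mode: Dropout is active
print(model.training)    # True

model.eval()             # evaluation mode: Dropout becomes a no-op
print(model.training)    # False
```

Note that eval() is just shorthand for train(False); both recursively set the flag on all submodules.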
Use both. They do different things, and have different scopes.
- with torch.no_grad() disables tracking of gradients in autograd.
- model.eval() changes the forward() behaviour of the module it is called upon.
with torch.no_grad()
The torch.autograd.no_grad documentation says:

Context-manager that disabled [sic] gradient calculation.

Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True. In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.
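The behaviour described in the quote can be seen directly: outside the context manager a computation on a gradient-tracking tensor produces a tracked result, while inside it the result is detached.

```python
import torch

x = torch.ones(3, requires_grad=True)

y = x * 2                # gradient tracking on: y is part of the graph
print(y.requires_grad)   # True

with torch.no_grad():
    z = x * 2            # inside no_grad: result is detached from the graph
print(z.requires_grad)   # False
```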
model.eval()
The nn.Module.eval documentation says:

Sets the module in evaluation mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
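Dropout makes the difference easy to observe: in training mode it zeroes elements at random (scaling the survivors by 1/(1-p)), while in evaluation mode it is the identity function.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(10)

drop.train()
# Training mode: roughly half the elements are zeroed,
# the rest are scaled by 1/(1-p) = 2.
print(drop(x))

drop.eval()
# Evaluation mode: the input passes through unchanged.
print(drop(x))   # tensor of ones
```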
The creator of PyTorch said the documentation should be updated to suggest the usage of both, and I raised a pull request.
with torch.no_grad():
disables computation of gradients for the backward pass. Since these calculations are unnecessary during inference, and add non-trivial computational overhead, it is essential to use this context manager if you are evaluating the model's speed. It will not, however, affect results.
model.eval()
ensures certain modules which behave differently in training vs inference (e.g. Dropout and BatchNorm) are defined appropriately during the forward pass in inference. As such, if your model contains such modules it is essential to enable this.
For the reasons above it is good practice to use both during inference.
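Putting the two together, a validation pass might look like the following sketch (the evaluate function, model, val_loader, and loss_fn names are illustrative, not from any particular codebase):

```python
import torch

def evaluate(model, val_loader, loss_fn):
    """Hypothetical validation loop using both model.eval() and torch.no_grad()."""
    model.eval()                      # switch Dropout/BatchNorm to inference behaviour
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():             # skip gradient bookkeeping: faster, less memory
        for inputs, targets in val_loader:
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item()
            n_batches += 1
    model.train()                     # restore training mode afterwards
    return total_loss / max(n_batches, 1)
```

Remember to switch back to model.train() before resuming training, otherwise mode-dependent layers will keep their inference behaviour.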