According to the official documentation, calling train() or eval() affects certain modules. I would like to achieve the same with my custom module, i.e. have it do one thing when train() is active and something different when eval() is active. How can I do this?
Yes, you can.
As you can see in the source code, eval() and train() essentially toggle a flag called self.training (note that train() is applied recursively to all child modules):
def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self: T) -> T:
    return self.train(False)
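To see how the recursion above propagates the flag through a module tree, here is a minimal pure-Python sketch (no torch required; TinyModule is a hypothetical stand-in for nn.Module with the same train()/eval() logic):

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module, just enough to show the recursion."""
    def __init__(self, children=()):
        self.training = True          # modules start in training mode
        self._children = list(children)

    def children(self):
        return iter(self._children)

    # Same logic as the nn.Module source shown above.
    def train(self, mode=True):
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        return self.train(False)


leaf = TinyModule()
root = TinyModule(children=[leaf])
root.eval()
print(root.training, leaf.training)   # False False — the flag reaches the child
root.train()
print(root.training, leaf.training)   # True True
```

Calling eval() on the root is enough: every descendant's self.training is flipped by the recursive call.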
This flag is available on every nn.Module. If your custom module inherits from this base class, achieving what you want is quite simple:
import torch.nn as nn

class MyCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        # [...]

    def forward(self, x):
        if self.training:
            # train() -> training logic
            ...
        else:
            # eval() -> inference logic
            ...
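For a concrete (made-up) example of this pattern, here is a sketch of a custom module that perturbs its input during training and acts as the identity during inference; the module name and the noise behavior are illustrative, not from the original question:

```python
import torch
import torch.nn as nn

class NoisyDuringTraining(nn.Module):
    """Hypothetical custom module: adds Gaussian noise in train(), no-op in eval()."""
    def __init__(self, std=0.1):
        super().__init__()
        self.std = std

    def forward(self, x):
        if self.training:
            # train() -> training logic: perturb the input
            return x + self.std * torch.randn_like(x)
        # eval() -> inference logic: pass the input through unchanged
        return x


m = NoisyDuringTraining()
x = torch.zeros(4)
m.train()
print(m.training)             # True
m.eval()
print(torch.equal(m(x), x))   # True: identity in eval mode
```

Because the branch keys off self.training, plain model.train() and model.eval() calls on any parent module switch the behavior automatically, just like Dropout or BatchNorm.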