Can I make my custom PyTorch modules behave differently when train() or eval() is called?

Tags:

python

pytorch

According to the official documentation, calling train() or eval() affects certain modules. I would like to achieve something similar with my custom module: it should do one thing after train() is called and something different after eval() is called. How can I do this?

asked Jul 31 '20 by ihdv


People also ask

What does eval() do in PyTorch?

Call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call model.train() to set these layers back to training mode.
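A minimal sketch of this workflow, assuming a toy model made of an nn.Linear followed by an nn.Dropout (the model and the input shapes are illustrative only): in evaluation mode two forward passes on the same input give identical results, and calling train() switches the layers back.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model for illustration: a linear layer followed by dropout
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()  # dropout is disabled, so inference is deterministic
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
print(torch.equal(out1, out2))  # True

model.train()  # resume training: dropout is active again
print(model.training, model[1].training)  # True True
```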

What does train() do in PyTorch?

model.train() tells your model that you are training it. This informs layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation.
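To see what "informing" a layer means in practice, here is a small sketch with nn.BatchNorm1d (the feature size and the input data are arbitrary choices): in training mode the layer updates its running statistics from each batch, while in evaluation mode it only reads them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)           # running_mean starts at zeros
x = torch.randn(8, 3) * 5 + 10   # batch whose mean is far from zero

bn.train()
bn(x)  # training mode: normalizes with batch statistics and updates running stats
after_train = bn.running_mean.clone()

bn.eval()
bn(x)  # eval mode: normalizes with the stored running stats, leaves them unchanged
after_eval = bn.running_mean.clone()

print(bn.num_batches_tracked.item())         # 1: only the training pass was counted
print(torch.equal(after_train, after_eval))  # True
```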

What does module.eval() do?

model.eval() is a kind of switch for specific layers or parts of the model that behave differently during training and inference (evaluation), for example Dropout and BatchNorm layers.
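Only layers whose forward() actually reads the training flag change behavior. As a small sketch (the layer sizes are arbitrary), an nn.Linear layer, which has no mode-dependent logic, produces exactly the same output before and after eval():

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2)
x = torch.randn(4, 3)

linear.train()
y_train = linear(x)

linear.eval()  # flips linear.training to False ...
y_eval = linear(x)

# ... but nn.Linear never consults that flag, so the outputs match
print(linear.training, torch.equal(y_train, y_eval))  # False True
```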

What do model.train() and model.eval() do?

Calling model.train() for training and model.eval() for evaluation automatically sets the mode of dropout and batch normalization layers, which then adjust their behavior (for example, dropout only rescales activations during training), so you do not have to handle this yourself.
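The rescaling refers to dropout: in training mode, nn.Dropout zeroes each element with probability p and scales the surviving elements by 1 / (1 - p), while in evaluation mode it returns its input unchanged. A quick sketch with p = 0.5 (an arbitrary choice, giving a scale factor of 2):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)  # each entry is either 0.0 (dropped) or 1 / (1 - 0.5) = 2.0

drop.eval()
y_eval = drop(x)   # identity: dropout does nothing in evaluation mode

print(sorted(y_train.unique().tolist()))  # [0.0, 2.0]
print(torch.equal(y_eval, x))             # True
```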


1 Answer

Yes, you can.

As you can see in the source code, eval() and train() essentially just set a flag called self.training (note that train() applies it recursively to all child modules):

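# Simplified excerpt from torch.nn.Module
def train(self: T, mode: bool = True) -> T:
    self.training = mode            # set this module's own flag
    for module in self.children():  # then apply the same mode to each child,
        module.train(mode)          # which recurses into its own children
    return self

def eval(self: T) -> T:
    return self.train(False)        # eval() is just train(False)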
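Because train() recurses through self.children(), a single call on the top-level model reaches every nested submodule. A small sketch, assuming an arbitrary nested nn.Sequential as the model:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.Dropout(0.1), nn.ReLU()),  # nested container
)

returned = model.eval()  # same as model.train(False); returns the module itself
print(returned is model)  # True
print(all(not m.training for m in model.modules()))  # True: nested Dropout included

model.train()
print(all(m.training for m in model.modules()))  # True
```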

This flag is available in every nn.Module. If your custom module inherits from this base class, achieving what you want is simple:

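import torch.nn as nn


class MyCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        # [...]

    def forward(self, x):
        if self.training:
            # train() was called -> training logic goes here
            return x  # placeholder: replace with your training-time computation
        else:
            # eval() was called -> inference logic goes here
            return x  # placeholder: replace with your inference-time computation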
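As a concrete (hypothetical) instance of this pattern, the sketch below defines a GaussianNoise module, with a name and a std parameter invented here for illustration, that perturbs its input only in training mode and passes it through unchanged after eval():

```python
import torch
import torch.nn as nn


class GaussianNoise(nn.Module):
    """Hypothetical example module: adds Gaussian noise only while training."""

    def __init__(self, std: float = 0.1):
        super().__init__()
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # train() mode: add zero-mean noise with standard deviation self.std
            return x + torch.randn_like(x) * self.std
        # eval() mode: identity
        return x


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 3), GaussianNoise(std=0.5))
x = torch.randn(2, 3)

model.eval()  # propagates to GaussianNoise via self.children()
with torch.no_grad():
    same = torch.equal(model(x), model[0](x))   # no noise added

model.train()
with torch.no_grad():
    noisy = torch.equal(model(x), model[0](x))  # noise added, so outputs differ

print(same, noisy)  # True False
```

Because the check lives inside forward(), no extra wiring is needed: calling eval() or train() on any parent container is enough.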
answered Oct 21 '22 by Berriel