Which PyTorch modules are affected by model.eval() and model.train()?

The model.eval() method modifies certain modules (layers) which are required to behave differently during training and inference. Some examples are listed in the docs:

> This has [an] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
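To make the quoted behaviour concrete, here is a minimal sketch using `nn.Dropout`, one of the modules the docs name. The tensor size, the seed, and the small `nn.Sequential` model are arbitrary choices for illustration and are not part of the original question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: elements are zeroed at random and the survivors
# are scaled by 1 / (1 - p), i.e. by 2 here.
drop.train()
y_train = drop(x)
zero_fraction = (y_train == 0).float().mean().item()

# Evaluation mode: dropout acts as the identity function.
drop.eval()
y_eval = drop(x)

# eval() is applied recursively, so calling it on a parent module
# switches every submodule as well.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.eval()
all_in_eval = all(not m.training for m in model.modules())
```

The `training` attribute checked at the end is the flag that `train()` and `eval()` toggle; a module is "affected" only if its `forward` actually consults that flag.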

Is there an exhaustive list of which modules are affected?

asked Dec 30 '22 16:12 by iacob
1 Answer

In addition to the info provided by @iacob, the following modules are also affected:

| Base class | Module | Criteria |
|---|---|---|
| RNNBase | RNN, LSTM, GRU | dropout > 0 (default: 0) |
| Transformer layers | Transformer, TransformerEncoder, TransformerDecoder | dropout > 0 (Transformer default: 0.1) |
| Lazy variants | LazyBatchNorm (currently nightly, merged PR) | track_running_stats=True |
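The criteria column can be checked directly. The following is a rough sketch, assuming a CPU run with arbitrary layer sizes and seed: it shows an `nn.LSTM` with `dropout > 0` producing different outputs across calls in training mode but not in evaluation mode, and an `nn.BatchNorm1d` layer that only depends on the mode when `track_running_stats=True`. Note that RNN dropout is applied between stacked layers, so `num_layers` must be greater than 1 for it to have any effect.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

# --- RNNBase: affected only when dropout > 0 ---
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, dropout=0.5)
x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)

lstm.train()
out_a, _ = lstm(x)
out_b, _ = lstm(x)
# A fresh dropout mask is drawn on each call in training mode.
train_outputs_differ = not torch.allclose(out_a, out_b)

lstm.eval()
with torch.no_grad():
    out_c, _ = lstm(x)
    out_d, _ = lstm(x)
# Dropout is disabled in eval mode, so the output is deterministic.
eval_outputs_match = torch.allclose(out_c, out_d)

# --- BatchNorm: affected when track_running_stats=True (the default) ---
batch = torch.randn(32, 4) + 3.0

bn = nn.BatchNorm1d(4)
bn.train()
bn(batch)  # updates running_mean / running_var from the batch
mean_after_train = bn.running_mean.clone()
bn.eval()
bn(batch)  # uses the stored statistics and leaves them untouched
mean_after_eval = bn.running_mean.clone()

# With track_running_stats=False there are no stored statistics, so
# batch statistics are used in both modes and the outputs coincide.
bn_no_stats = nn.BatchNorm1d(4, track_running_stats=False)
bn_no_stats.train()
y_train_mode = bn_no_stats(batch)
bn_no_stats.eval()
y_eval_mode = bn_no_stats(batch)
same_in_both_modes = torch.allclose(y_train_mode, y_eval_mode)
```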
answered May 12 '23 17:05 by Szymon Maszke