Is there any difference between torch.optim.Adam(weight_decay=0.01)
and torch.optim.AdamW(weight_decay=0.01)
?
Link to the docs: torch.optim
Yes, Adam and AdamW handle weight decay differently.
Loshchilov and Hutter pointed out in their paper (Decoupled Weight Decay Regularization) that the way weight decay is implemented in Adam in every library seems to be wrong, and proposed a simple way (which they call AdamW) to fix it.
In Adam, weight decay is usually implemented by adding wd * w (where wd is the weight decay coefficient) to the gradients (1st case), rather than actually subtracting it from the weights (2nd case).
# 1st case: L2 regularization -- the penalty is added to the loss,
# which is how Adam's weight_decay is implemented
final_loss = loss + wd * all_weights.pow(2).sum() / 2
# 2nd case: weight decay -- applied directly in the update step (shown for SGD)
w = w - lr * w.grad - lr * wd * w
These two methods are the same for vanilla SGD, but as soon as we add momentum, or use a more sophisticated optimizer like Adam, L2 regularization (first equation) and weight decay (second equation) become different.
AdamW follows the second equation for weight decay.
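To see where the two formulations diverge, here is a minimal sketch of a single update step for one parameter tensor. This is not the actual PyTorch source and it ignores bias correction; it only illustrates where the decay term enters each update.

import torch

def adam_step(w, grad, m, v, lr=1e-3, wd=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # Adam: weight decay is folded into the gradient (L2 regularization),
    # so it gets rescaled by the adaptive denominator sqrt(v) + eps
    grad = grad + wd * w
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    w = w - lr * m / (v.sqrt() + eps)
    return w, m, v

def adamw_step(w, grad, m, v, lr=1e-3, wd=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # AdamW: weight decay is decoupled -- subtracted from the weights directly,
    # untouched by the adaptive moment estimates
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    w = w - lr * m / (v.sqrt() + eps) - lr * wd * w
    return w, m, v

w, g = torch.tensor([1.0]), torch.tensor([0.5])
m, v = torch.zeros(1), torch.zeros(1)
print(adam_step(w, g, m, v)[0], adamw_step(w, g, m, v)[0])  # the updates differ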
In Adam:
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
In AdamW:
weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
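In practice the two optimizers are constructed the same way; only the meaning of weight_decay differs. A quick sketch (model and hyperparameters here are just placeholders):

import torch

model = torch.nn.Linear(10, 1)

# L2-style decay: 0.01 * w is added to the gradient before the
# adaptive moment estimates are computed
adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)

# Decoupled decay: the weights are shrunk by lr * 0.01 * w directly,
# independent of the gradient statistics
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)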
Read more on the fastai blog.