
AdamW and Adam with weight decay

Tags:

pytorch

Is there any difference between torch.optim.Adam(weight_decay=0.01) and torch.optim.AdamW(weight_decay=0.01)?

Link to the docs: torch.optim

asked Oct 31 '20 by Long Luu

1 Answer

Yes, Adam and AdamW handle weight decay differently.

Loshchilov and Hutter pointed out in their paper (Decoupled Weight Decay Regularization) that the way weight decay is implemented in Adam in every library seems to be wrong, and proposed a simple way to fix it, which they call AdamW.

In Adam, weight decay is usually implemented by adding wd*w (where wd is the weight decay factor) to the gradients (first case below), rather than subtracting it directly from the weights (second case).

# First case: Adam-style weight decay (L2 regularization), equivalent to
# adding wd*w to the gradients
final_loss = loss + wd * all_weights.pow(2).sum() / 2
# Second case: weight decay proper, subtracted directly from the weights (shown for SGD)
w = w - lr * w.grad - lr * wd * w

These two methods are the same for vanilla SGD, but as soon as we add momentum or use a more sophisticated optimizer like Adam, L2 regularization (the first equation) and weight decay (the second equation) become different.

AdamW follows the second equation for weight decay.
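A rough sketch of a single update step for one tensor makes the difference concrete. This is not the actual torch.optim source, the function names are just for illustration, and the moment updates follow the standard Adam formulas:

import torch

def adam_style_step(w, grad, m, v, t, lr=1e-3, wd=1e-2,
                    beta1=0.9, beta2=0.999, eps=1e-8):
    # Adam with weight_decay: wd*w is folded into the gradient, so the decay
    # term is rescaled by the adaptive 1/sqrt(v_hat) factor like everything else
    grad = grad + wd * w
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    w = w - lr * m_hat / (v_hat.sqrt() + eps)
    return w, m, v

def adamw_style_step(w, grad, m, v, t, lr=1e-3, wd=1e-2,
                     beta1=0.9, beta2=0.999, eps=1e-8):
    # AdamW: the moments use the raw gradient; lr*wd*w is subtracted from the
    # weights directly, outside the adaptive scaling (decoupled weight decay)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    w = w - lr * m_hat / (v_hat.sqrt() + eps) - lr * wd * w
    return w, m, v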

In Adam

weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

In AdamW

weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
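In practice both optimizers are constructed the same way in PyTorch; only the decay semantics (and the default value) differ. A quick sketch with a toy layer:

import torch

model = torch.nn.Linear(10, 1)

# wd*w is folded into the gradient and rescaled by the adaptive term
opt_adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)

# wd*w is applied directly to the weights (decoupled), as in the paper
opt_adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)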

Read more on the fastai blog.

answered Oct 11 '22 by kHarshit