
About pytorch learning rate scheduler

Tags:

python

pytorch

Here is my code:

import torch
from torch import optim

net = torch.nn.Linear(10, 2)  # any model with parameters will do

optimizer = optim.SGD(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for i in range(15):
    lr = scheduler.get_lr()[0]             # LR as reported by the scheduler
    lr1 = optimizer.param_groups[0]["lr"]  # LR the optimizer actually uses
    print(i, lr, lr1)
    scheduler.step()

And here is the result:

0 0.1 0.1
1 0.1 0.1
2 0.1 0.1
3 0.1 0.1
4 0.1 0.1
5 0.025 0.05
6 0.05 0.05
7 0.05 0.05
8 0.05 0.05
9 0.05 0.05
10 0.0125 0.025
11 0.025 0.025
12 0.025 0.025
13 0.025 0.025
14 0.025 0.025

We can see that when scheduler.step() is applied, the learning rate returned by get_lr() first drops to 0.25x the original value, then bounces back to 0.5x at the next iteration. Is this a problem with scheduler.get_lr() or with scheduler.step()?

About the environment

  • python=3.6.9
  • pytorch=1.1.0

In addition, I can't reproduce this problem when pytorch=0.4.1 is used.

Vincent_Ho asked Jan 05 '20

1 Answer

Yes, the "problem" is in the use of get_lr(). To get the current learning rate, what you actually need is get_last_lr().
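
For instance, the loop from the question could read the LR like this (a minimal sketch; it assumes a PyTorch version >= 1.4, where get_last_lr() exists):

import torch
from torch import optim

net = torch.nn.Linear(10, 2)  # placeholder model, just so we have parameters

optimizer = optim.SGD(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for i in range(15):
    lr = scheduler.get_last_lr()[0]        # last LR actually computed by the scheduler
    lr1 = optimizer.param_groups[0]["lr"]  # LR the optimizer is using
    print(i, lr, lr1)                      # the two values now agree at every step
    scheduler.step()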


If you take a look at the implementation:

def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]

When it is at step 5, the condition is not satisfied (last_epoch % step_size == 0, since step_size=5), so it returns lr * gamma. The awkward thing is that you should be getting a warning when you call get_lr() outside of the step() function (as you can see in the implementation above), and apparently you didn't. The warning was added only 3 months ago, so you won't have it on v1.1.0.
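
Here is a hand trace of that branch, plugging in the values from the output above, to show where the spurious 0.025 comes from:

# At i=5, step() has already set the optimizer LR to 0.1 * 0.5 = 0.05
last_epoch, step_size, gamma, lr = 5, 5, 0.5, 0.05

if last_epoch == 0 or last_epoch % step_size != 0:
    reported = lr          # branch taken at i = 0..4, 6..9, 11..14
else:
    reported = lr * gamma  # branch taken at i = 5: 0.05 * 0.5 = 0.025

print(reported)  # 0.025 -- the dip in the printout, while the real LR is 0.05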

For the sake of completeness, what the step() method does is add 1 to last_epoch and update the LR by calling the get_lr() function (see here):

self.last_epoch += 1
values = self.get_lr()
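
Note that inside step() this double use is intentional: the scheduler sets the _get_lr_called_within_step flag (checked in the get_lr() implementation above) around this call, so no warning fires and the gamma multiplication is the intended decay. The value only misfires when you call get_lr() yourself, outside of step().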
Berriel answered Sep 18 '22