What is the `last_epoch` parameter of PyTorch's optimizer schedulers for?

Regarding PyTorch's optimizer schedulers, what does the last_epoch argument actually mean?

The documentation says:

last_epoch (int) – The index of last epoch. Default: -1.

But that doesn't really explain much for those who, like me, are just learning about these schedulers.

I read most of that documentation, if not all, and I still could not understand what this argument does or why it is there.

Eduardo Reis asked Jul 04 '20 03:07


1 Answer

The last_epoch parameter is used when resuming training, so that the scheduler continues from where it left off earlier. Its value is incremented every time you call the scheduler's .step(). The default value of -1 means the schedule starts from the beginning.

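From the docs (this wording appears for schedulers that are stepped once per batch, such as CyclicLR and OneCycleLR):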

Since step() should be invoked after each batch instead of after each epoch, this number represents the total number of batches computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning.
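To make the batch-counting point concrete, here is a minimal sketch, assuming OneCycleLR stepped once per batch. The model, the loop sizes (`epochs`, `batches_per_epoch`) and the learning rates are made-up values for illustration, and no real training happens in the loop.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
epochs, batches_per_epoch = 2, 5

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    total_steps=epochs * batches_per_epoch,
    cycle_momentum=False,  # keep the sketch independent of momentum settings
)

for epoch in range(epochs):
    for batch in range(batches_per_epoch):
        optimizer.step()   # a real loop would run forward/backward first
        scheduler.step()   # stepped per batch, so the counter advances per batch

# Despite its name, last_epoch now holds the number of batches seen.
print(scheduler.last_epoch)
```

With these values the counter ends at 10, the number of batches, rather than 2, the number of epochs.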

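For example (note that in recent PyTorch versions, calling get_lr() outside of step() returns a value one decay ahead for StepLR, which is why the rate at epoch 1 prints as 0.001 rather than 0.01; get_last_lr() reports the rate actually in use):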

>>> import torch
>>> cc = torch.nn.Conv2d(10, 10, 3)
>>> myoptimizer = torch.optim.Adam(cc.parameters(), lr=0.1)
>>> myscheduler = torch.optim.lr_scheduler.StepLR(myoptimizer, step_size=1, gamma=0.1)
>>> myscheduler.last_epoch, myscheduler.get_lr()
(0, [0.1])
>>> myscheduler.step()
>>> myscheduler.last_epoch, myscheduler.get_lr()
(1, [0.001])
>>> myscheduler.step()
>>> myscheduler.last_epoch, myscheduler.get_lr()
(2, [0.0001])
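As a cross-check on the values above, here is a small sketch that reads the rate through get_last_lr() and compares it with the closed form `0.1 * 0.1 ** last_epoch` that StepLR with step_size=1 should follow. The Linear model and SGD optimizer are stand-ins chosen for brevity, not part of the original example.

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

history = []
for _ in range(3):
    # get_last_lr() returns the rate currently applied to each param group.
    history.append((scheduler.last_epoch, scheduler.get_last_lr()[0]))
    optimizer.step()   # called first to avoid the step-order warning
    scheduler.step()

for epoch, lr in history:
    # Second column: rate in use; third column: closed-form expectation.
    print(epoch, lr, 0.1 * 0.1 ** epoch)
```

The two columns should agree up to floating-point rounding, with the rate dropping by a factor of 10 per step.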

Now, if you stop training partway through and later resume it, you can pass the last_epoch parameter to the scheduler so that it starts from where it left off instead of from the beginning again.

>>> mynewscheduler = torch.optim.lr_scheduler.StepLR(myoptimizer,step_size=1, gamma=0.1, last_epoch=myscheduler.last_epoch)
>>> mynewscheduler.last_epoch, mynewscheduler.get_lr()
(3, [1.0000000000000004e-05])
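Note that the example above reuses the same optimizer object, which already carries the `initial_lr` entry the first scheduler added to its param groups. As far as I know, passing last_epoch to a scheduler built on a fresh optimizer raises a KeyError about the missing `initial_lr`. When resuming from a checkpoint, a common alternative is to save and restore state_dict() for both objects. A minimal sketch, assuming a hypothetical helper `make_pair()` and an in-memory `checkpoint` dict instead of a file:

```python
import torch

def make_pair():
    """Hypothetical helper: build a fresh optimizer/scheduler pair."""
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    return optimizer, scheduler

# "First run": advance the schedule by two steps.
optimizer, scheduler = make_pair()
for _ in range(2):
    optimizer.step()
    scheduler.step()

# In practice this dict would be written with torch.save(...).
checkpoint = {
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}

# "Resumed run": rebuild both objects, then restore their state.
new_optimizer, new_scheduler = make_pair()
new_optimizer.load_state_dict(checkpoint["optimizer"])
new_scheduler.load_state_dict(checkpoint["scheduler"])

print(new_scheduler.last_epoch, new_scheduler.get_last_lr())
```

Restoring the state this way brings back last_epoch together with the current learning rate, so there is no need to pass last_epoch by hand.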
kHarshit answered Sep 19 '22 11:09