 

How to save LambdaLR scheduler in pytorch with lambda function?

Running PyTorch 0.4.1 with Python 3.6, I encountered this problem:
I cannot torch.save my learning rate scheduler because Python won't pickle a lambda function:

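import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.SGD(torch.nn.Linear(2, 1).parameters(), lr=0.1)  # hypothetical optimizer, for illustration only
lambda1 = lambda epoch: epoch // 30
scheduler = LambdaLR(optimizer, lr_lambda=lambda1)
torch.save(scheduler.state_dict(), 'scheduler.pth.tar')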

results in an error:

PicklingError: Can't pickle <function <lambda> at 0x7f7583fe92f0>: attribute lookup <lambda> on __main__ failed
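The error comes from Python's standard pickle module rather than from PyTorch itself: pickle serializes a function by reference, recording its module and qualified name, and a lambda's name is <lambda>, which cannot be looked up on __main__. A minimal sketch using only the standard library (the names every_30 and can_pickle are hypothetical, and the script is assumed to run as __main__):

```python
import pickle


def every_30(epoch):
    # Module-level function: pickle stores it as a reference to "__main__.every_30".
    return epoch // 30


lambda1 = lambda epoch: epoch // 30  # its __qualname__ is "<lambda>", so lookup by name fails


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


print(can_pickle(every_30))  # True
print(can_pickle(lambda1))  # False
```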

How can I save my scheduler?


I know that a scheduler built with a proper (named) function instead of a lambda for lambda1 can be saved, but I need a lambda function because I want to control this function when it is defined (for instance, I want to be able to change the fixed 30 in the denominator).
How can this be done while still allowing me to save the scheduler?

asked Oct 28 '25 17:10 by Shai


2 Answers

If one wishes to stay with the default behavior of torch.save and torch.load, the lambda function can be replaced with a callable class, for example:

class LRPolicy(object):
    def __init__(self, rate=30):
        self.rate = rate

    def __call__(self, epoch):
        return epoch // self.rate

The scheduler is now created with:

scheduler = LambdaLR(optimizer, lr_lambda=LRPolicy(rate=30))

Now the scheduler can be saved with torch.save and loaded with torch.load without changing the pickling module.
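Since torch.save uses pickle by default, the reason the class works can be checked with the standard library alone: pickle records the instance's class by name (here __main__.LRPolicy) and its attribute dictionary by value, so the configured rate survives a round trip. A minimal sketch, assuming the script runs as __main__; it repeats the LRPolicy class from above so it runs on its own, and it does not involve PyTorch:

```python
import pickle


class LRPolicy(object):
    def __init__(self, rate=30):
        self.rate = rate

    def __call__(self, epoch):
        return epoch // self.rate


policy = LRPolicy(rate=10)  # the divisor is chosen at construction time

# Serialize and restore: the class is stored by reference, self.rate by value.
restored = pickle.loads(pickle.dumps(policy))

print(restored.rate)  # 10
print(restored(25))  # 2
```

Changing the divisor only requires passing a different rate, e.g. LRPolicy(rate=50), which covers the requirement in the question of controlling the fixed 30 when the function is defined.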

answered Oct 31 '25 10:10 by Shai


You can use dill as the pickle module instead of the default pickle:

import dill
import torch
# dill can serialize lambdas, which the standard pickle module cannot
torch.save(scheduler.state_dict(), 'scheduler.pth.tar', pickle_module=dill)

and to load:

import dill
import torch
state_dict = torch.load('scheduler.pth.tar', pickle_module=dill)
scheduler.load_state_dict(state_dict)

See the documentation of torch.save for more options.

answered Oct 31 '25 10:10 by Umang Gupta