
Confused about the lambda expression in Python

Tags:

python

lambda

I understand normal lambda expressions, such as

g = lambda x: x**2

However, I am a little confused by some more complex ones. For example:

for split in ['train', 'test']:
    sets = (lambda split=split: newspaper(split, newspaper_devkit_path))

def get_imdb():
    return sets()

Here newspaper is a function. I was wondering what sets actually is, and why the get_imdb function can return sets().

Thanks for your help!

Added: The code actually comes from factory.py.

Panfeng Li asked Oct 30 '22 01:10



1 Answer

sets is being assigned a lambda that is not really meant to be called with arguments, as you can see from the way it is invoked: sets(), with nothing in the parentheses. Lambdas in general behave like normal functions, so they can be assigned to variables such as g or sets. The definition of sets is wrapped in an extra pair of parentheses for no apparent reason; you can ignore those outer parentheses.
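A minimal sketch of the "lambdas are just functions" point. The names g_def and make_greeting are hypothetical and exist only for illustration.

```python
# A lambda expression builds an ordinary function object, so it can be
# bound to a name just like a function created with def.
g = lambda x: x**2

def g_def(x):
    return x**2

print(g(3), g_def(3))        # 9 9
print(type(g) is type(g_def))  # True: both are plain function objects

# Extra parentheses around a lambda are only grouping and change nothing.
make_greeting = (lambda: 'hello')

# A lambda with no required parameters is called with empty parentheses,
# which is how sets() is invoked inside get_imdb().
print(make_greeting())       # hello
```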

Lambdas can have all the same kinds of positional, keyword and default parameters that a normal function can. The lambda assigned to sets has a default parameter named split. This is a common idiom to ensure that the lambda created in each iteration of the loop keeps the value of split from that iteration, rather than every lambda ending up with the value from the last iteration.
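To illustrate the parameter kinds a lambda accepts, here is a small sketch; f and pick are hypothetical names chosen for the example.

```python
# Lambdas accept the same parameter kinds as def functions: positional,
# defaulted, *args, keyword-only (c, because it follows *args) and **kwargs.
f = lambda a, b=2, *args, c, **kwargs: (a, b, args, c, kwargs)

print(f(1, c=3))                # (1, 2, (), 3, {})
print(f(1, 5, 6, 7, c=3, d=4))  # (1, 5, (6, 7), 3, {'d': 4})

# A lambda whose only parameter has a default can be called with no
# arguments, like sets() in the question, or with an explicit override.
pick = lambda split='train': split
print(pick())        # train
print(pick('test'))  # test
```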

Without a default parameter, split would be a free variable inside the lambda, looked up in the enclosing namespace at the time the lambda is called, not when it is created. Once the loop completes, split in that enclosing namespace just holds the last value it had in the loop.
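The call-time lookup does not depend on a loop at all; rebinding the outer name after the lambda exists is enough to show it. A minimal sketch, with late and early as hypothetical names:

```python
# Without a default, `split` inside the lambda is a free variable that is
# looked up in the enclosing scope each time the lambda is called.
split = 'train'
late = lambda: split
split = 'test'      # rebinding the outer name after the lambda was created...
print(late())       # test  (...changes what the lambda sees)

# With a default parameter, the value is captured when the lambda is created.
split = 'train'
early = lambda split=split: split
split = 'test'
print(early())      # train
```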

Default parameter values are evaluated once, immediately when the function object is created. This means the default value of split in each lambda is whatever split was in the loop iteration that created it.
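One way to watch default values being evaluated at creation time is to route them through a helper that logs each evaluation. The helper record and the list calls are hypothetical; __defaults__ is the standard attribute where a function object stores its default values.

```python
calls = []

def record(value):
    # Hypothetical helper: logs each time a default expression is evaluated.
    calls.append(value)
    return value

sets = []
for split in ['train', 'test']:
    sets.append(lambda split=record(split): split)

print(calls)                             # ['train', 'test'] (evaluated during the loop)
print([fn() for fn in sets])             # ['train', 'test']
print(calls)                             # ['train', 'test'] (calling added nothing)
print([fn.__defaults__ for fn in sets])  # [('train',), ('test',)]
```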

Your example is a bit misleading because each iteration overwrites sets, discarding every lambda except the last one and making the default parameter meaningless. Here is an example illustrating what happens if you keep all the lambdas. First, with the default parameter:

sets = []
for split in ['train', 'test']:
    sets.append(lambda split=split: split)
print([fn() for fn in sets])

I have simplified the lambdas to just return their parameter, for purposes of illustration. This example prints ['train', 'test'], as expected.

If you do the same thing without the default parameter, the output will be ['test', 'test'] instead:

sets = []
for split in ['train', 'test']:
    sets.append(lambda: split)
print([fn() for fn in sets])

This is because 'test' is the value of split at the time the lambdas are called, which is after the loop has finished.
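Factory modules like the one the question quotes commonly keep every lambda in a dictionary keyed by name, so that get_imdb takes a name and calls the matching lambda to build the dataset lazily. The sketch below assumes that pattern; the newspaper stub, the placeholder path, the dictionary _sets and the key format are all hypothetical, not taken from the original factory.py.

```python
# Hypothetical stand-in for the question's newspaper() function.
def newspaper(split, devkit_path):
    return {'split': split, 'path': devkit_path}

newspaper_devkit_path = '/path/to/devkit'  # placeholder value (assumption)

# Keep every lambda by storing it under a name instead of overwriting one variable.
_sets = {}
for split in ['train', 'test']:
    name = 'newspaper_{}'.format(split)
    # split=split freezes this iteration's value. newspaper_devkit_path is not a
    # default, so it is looked up only when the lambda is eventually called.
    _sets[name] = lambda split=split: newspaper(split, newspaper_devkit_path)

def get_imdb(name):
    """Look up the factory lambda by name and call it to build the dataset."""
    if name not in _sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return _sets[name]()

print(get_imdb('newspaper_train'))  # {'split': 'train', 'path': '/path/to/devkit'}
print(get_imdb('newspaper_test'))   # {'split': 'test', 'path': '/path/to/devkit'}
```

Storing the lambdas rather than their results means no dataset is constructed until someone actually asks for it by name.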

Mad Physicist answered Nov 15 '22 07:11
