
Anaphoric list comprehension in Python

Consider the following toy example:

>>> def square(x): return x*x
... 
>>> [square(x) for x in range(12) if square(x) > 50]
[64, 81, 100, 121]

I have to call square(x) twice in the list comprehension. The duplication is ugly, bug-prone (it's easy to change only one of the two calls when modifying the code), and inefficient.

Of course I can do this:

>>> squares = [square(x) for x in range(12)]
>>> [s for s in squares if s > 50]
[64, 81, 100, 121]

or this:

[s for s in [square(x) for x in range(12)] if s > 50]

These are both livable, but it feels as though there might be a way to do it all in a single statement without nesting the two list comprehensions, which I know I'll have to stare at for a while the next time I read the code just to figure out what's going on. Is there a way?

I think a fair question to ask of me would be what I imagine such syntax could look like. Here are two ideas, but neither feels idiomatic in Python (nor do they work). They are inspired by anaphoric macros in Lisp.

[square(x) for x in range(12) if it > 50]
[it=square(x) for x in range(12) if it > 50]

kuzzooroo asked Nov 20 '13 00:11


2 Answers

You should use a generator expression:

[s for s in (square(x) for x in range(12)) if s > 50]

This avoids creating an intermediate unfiltered list of squares.
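
One way to check that the nested generator calls square only once per element is to count the calls. This is a minimal sketch; the calls counter is a hypothetical addition for illustration and is not part of the original code:

```python
calls = 0  # hypothetical counter, added only to illustrate the point


def square(x):
    global calls
    calls += 1
    return x * x


# The inner generator yields each square lazily, one at a time;
# the outer comprehension filters them without building a full list first.
result = [s for s in (square(x) for x in range(12)) if s > 50]

print(result)  # [64, 81, 100, 121]
print(calls)   # 12, i.e. one call per element of range(12)
```

With the original version that calls square(x) in both the condition and the output expression, the same counter would reach 16: 12 calls for the condition plus 4 for the values that pass the filter.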

arshajii answered Sep 28 '22 03:09


Here is a timing comparison of "chained" list comps vs a nested generator vs calculating twice:

$ python -m timeit "[s for n in range(12) for s in [n * n] if s > 50]"
100000 loops, best of 3: 2.48 usec per loop
$ python -m timeit "[s for s in (x * x for x in range(12)) if s > 50]"
1000000 loops, best of 3: 1.89 usec per loop
$ python -m timeit "[n * n for n in range(12) if n * n > 50]"
1000000 loops, best of 3: 1.1 usec per loop

$ pypy -m timeit "[s for n in range(12) for s in [n * n] if s > 50]"
1000000 loops, best of 3: 0.211 usec per loop
$ pypy -m timeit "[s for s in (x * x for x in range(12)) if s > 50]"
1000000 loops, best of 3: 0.359 usec per loop
$ pypy -m timeit "[n * n for n in range(12) if n * n > 50]"
10000000 loops, best of 3: 0.0834 usec per loop

I used n * n instead of square(n) because it was convenient and it removes the function call overhead from the benchmark.
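
The same comparison can probably be reproduced from a script with the standard timeit module rather than the command line. This is only a sketch: the labels and the loop count of 100,000 are my own assumptions, and the numbers will differ by machine and interpreter version, so no output is shown:

```python
import timeit

# The three statements from the command-line runs above.
statements = {
    "chained list comps": "[s for n in range(12) for s in [n * n] if s > 50]",
    "nested generator": "[s for s in (x * x for x in range(12)) if s > 50]",
    "calculating twice": "[n * n for n in range(12) if n * n > 50]",
}

NUMBER = 100_000  # loops per run; an arbitrary choice for this sketch

results = {}
for label, stmt in statements.items():
    # repeat=3 and taking the minimum mirrors the "best of 3" figure.
    best = min(timeit.repeat(stmt, repeat=3, number=NUMBER))
    results[label] = best / NUMBER * 1e6  # microseconds per loop
    print(f"{label}: {results[label]:.3f} usec per loop")
```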

TLDR: For simple cases it may be best to just duplicate the calculation.
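
For cases where duplicating the calculation is not acceptable, newer interpreters offer another option that is close to the "it" idea in the question. This sketch assumes Python 3.8 or later, which added assignment expressions; it was not available when the question was asked and is not part of the benchmark above:

```python
def square(x):
    return x * x


# Requires Python 3.8+: (s := square(x)) binds s inside the condition,
# and the same s is then reused as the output expression,
# so square is called only once per element.
result = [s for x in range(12) if (s := square(x)) > 50]

print(result)  # [64, 81, 100, 121]
```

Note that s is bound in the scope enclosing the comprehension, so it remains visible after the comprehension finishes.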

John La Rooy answered Sep 28 '22 05:09