 

Why is a `for` loop so much faster to count True values?

sum is quite fast, but sum isn't the cause of the slowdown. Three primary factors contribute to the slowdown:

  • The use of a generator expression causes overhead for constantly pausing and resuming the generator.
  • Your generator version adds unconditionally instead of only when the digit is even. This is more expensive when the digit is odd.
  • Adding booleans instead of ints prevents sum from using its integer fast path.

Generators offer two primary advantages over list comprehensions: they take a lot less memory, and they can terminate early if not all elements are needed. They are not designed to offer a time advantage in the case where all elements are needed. Suspending and resuming a generator once per element is pretty expensive.
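Here is a minimal sketch of the two advantages mentioned above, using only the standard library. A generator object has a small, constant size no matter how many elements it will produce, and a consumer like any() can stop pulling from it early. The helper name `flags` and the `seen` list are hypothetical, added only to record how many characters were actually examined.

```python
import sys

x = int('1234567890' * 50)
s = str(x)

# A list comprehension materializes every element up front...
as_list = [c in '02468' for c in s]
# ...while a generator expression only stores its paused state.
as_gen = (c in '02468' for c in s)
print(sys.getsizeof(as_list), sys.getsizeof(as_gen))

# Early termination: record each character the generator is asked for.
seen = []

def flags(text):
    """Hypothetical helper: yields evenness flags and logs consumed characters."""
    for c in text:
        seen.append(c)
        yield c in '02468'

# any() stops at the first True, which is the second character ('2').
found = any(flags(s))
print(found, seen)  # True ['1', '2']
```

When every element is needed, as with sum(), neither advantage applies, but the per-element suspend/resume cost is still paid.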

If we replace the genexp with a list comprehension:

In [66]: def f1(x):
   ....:     return sum(c in '02468' for c in str(x))
   ....: 
In [67]: def f2(x):
   ....:     return sum([c in '02468' for c in str(x)])
   ....: 
In [68]: x = int('1234567890'*50)
In [69]: %timeit f1(x)
10000 loops, best of 5: 52.2 µs per loop
In [70]: %timeit f2(x)
10000 loops, best of 5: 40.5 µs per loop

we see an immediate speedup, at the cost of wasting a bunch of memory on a list.


If you look at your genexp version:

def count_even_digits_spyr03_sum(n):
    return sum(c in "02468" for c in str(n))

you'll see it has no if. It just throws booleans into sum. In contrast, your loop:

def count_even_digits_spyr03_for(n):
    count = 0
    for c in str(n):
        if c in "02468":
            count += 1
    return count

only adds anything if the digit is even.
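To make the difference concrete, we can count how many values reach sum in each style, using the same 500-digit x as in the timings. This sketch uses list comprehensions only so the values can be counted; the point is the number of additions, not the container.

```python
x = int('1234567890' * 50)
s = str(x)

# Genexp style from the question: one value per digit, even or not.
unconditional = [c in '02468' for c in s]
# Explicit-loop style: only even digits produce something to add.
filtered = [1 for c in s if c in '02468']

print(len(unconditional), len(filtered))  # 500 250
print(sum(unconditional), sum(filtered))  # 250 250
```

Both give the same count, but the unconditional version performs twice as many additions here, half of them adding False.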

If we change the f2 defined earlier to also incorporate an if, we see another speedup:

In [71]: def f3(x):
   ....:     return sum([True for c in str(x) if c in '02468'])
   ....: 
In [72]: %timeit f3(x)
10000 loops, best of 5: 34.9 µs per loop

f1, identical to your original code, took 52.2 µs, and f2, with just the list comprehension change, took 40.5 µs.


It probably looked pretty awkward using True instead of 1 in f3. That's because changing it to 1 activates one final speedup. sum has a fast path for integers, but the fast path only activates for objects whose type is exactly int. bool doesn't count. This is the line that checks that items are of type int:

if (PyLong_CheckExact(item)) {
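bool is a subclass of int, so an isinstance check accepts it, but an exact type check like PyLong_CheckExact does not. A rough Python-level illustration follows; `is_exact_int` is a hypothetical name, and the real check is the C macro above, whose surrounding fast-path logic may differ between CPython versions.

```python
def is_exact_int(obj):
    """Python analogue of PyLong_CheckExact: the type must be exactly int."""
    return type(obj) is int

print(isinstance(True, int))  # True: bool subclasses int
print(is_exact_int(True))     # False: an exact check rejects bools
print(is_exact_int(1))        # True

# The result is the same either way; only the speed can differ.
print(sum([True, True, False]), sum([1, 1, 0]))  # 2 2
```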

Once we make the final change, changing True to 1:

In [73]: def f4(x):
   ....:     return sum([1 for c in str(x) if c in '02468'])
   ....: 
In [74]: %timeit f4(x)
10000 loops, best of 5: 33.3 µs per loop

we see one last small speedup.


So after all that, do we beat the explicit loop?

In [75]: def explicit_loop(x):
   ....:     count = 0
   ....:     for c in str(x):
   ....:         if c in '02468':
   ....:             count += 1
   ....:     return count
   ....: 
In [76]: %timeit explicit_loop(x)
10000 loops, best of 5: 32.7 µs per loop

Nope. We've roughly broken even, but we're not beating it. The big remaining problem is the list. Building it is expensive, and sum has to go through the list iterator to retrieve elements, which has its own cost (though I think that part is pretty cheap). Unfortunately, as long as we're going through the test-digits-and-call-sum approach, we don't have any good way to get rid of the list. The explicit loop wins.

Can we go further anyway? Well, we've been trying to bring the sum closer to the explicit loop so far, but if we're stuck with this dumb list, we could diverge from the explicit loop and just call len instead of sum:

def f5(x):
    return len([1 for c in str(x) if c in '02468'])

Testing digits individually isn't the only way we can try to beat the loop, either. Diverging even further from the explicit loop, we can also try str.count. str.count iterates over a string's buffer directly in C, avoiding a lot of wrapper objects and indirection. We need to call it 5 times, making 5 passes over the string, but it still pays off:

def f6(x):
    s = str(x)
    return sum(s.count(c) for c in '02468')
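Since f5 and f6 take quite different routes, it's worth checking that they still agree with the explicit loop. A quick randomized cross-check follows; the seed and the range of input sizes are arbitrary choices.

```python
import random

def explicit_loop(x):
    count = 0
    for c in str(x):
        if c in '02468':
            count += 1
    return count

def f5(x):
    # Filtered list, counted with len() instead of sum().
    return len([1 for c in str(x) if c in '02468'])

def f6(x):
    s = str(x)
    # Five C-level passes over the string, one per even digit.
    return sum(s.count(c) for c in '02468')

# Compare all three on random non-negative integers of up to 300 digits.
rng = random.Random(0)
for _ in range(1000):
    n = rng.randrange(10 ** rng.randint(1, 300))
    assert f5(n) == explicit_loop(n) == f6(n)
```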

Unfortunately, this is the point when the site I was using for timing stuck me in the "tarpit" for using too many resources, so I had to switch sites. The following timings are not directly comparable with the timings above:

>>> import timeit
>>> def f(x):
...     return sum([1 for c in str(x) if c in '02468'])
... 
>>> def g(x):
...     return len([1 for c in str(x) if c in '02468'])
... 
>>> def h(x):
...     s = str(x)
...     return sum(s.count(c) for c in '02468')
... 
>>> x = int('1234567890'*50)
>>> timeit.timeit(lambda: f(x), number=10000)
0.331528635986615
>>> timeit.timeit(lambda: g(x), number=10000)
0.30292080697836354
>>> timeit.timeit(lambda: h(x), number=10000)
0.15950968803372234
>>> def explicit_loop(x):
...     count = 0
...     for c in str(x):
...         if c in '02468':
...             count += 1
...     return count
... 
>>> timeit.timeit(lambda: explicit_loop(x), number=10000)
0.3305045129964128
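If you want to reproduce these numbers outside IPython, a stdlib-only harness along these lines should work. It takes the best of a few runs to reduce noise; the repeat counts are arbitrary, and absolute timings will depend on your machine and Python version.

```python
import timeit

def f(x):
    return sum([1 for c in str(x) if c in '02468'])

def g(x):
    return len([1 for c in str(x) if c in '02468'])

def h(x):
    s = str(x)
    return sum(s.count(c) for c in '02468')

def explicit_loop(x):
    count = 0
    for c in str(x):
        if c in '02468':
            count += 1
    return count

x = int('1234567890' * 50)

# Every variant must agree before comparing their timings means anything.
assert f(x) == g(x) == h(x) == explicit_loop(x) == 250

for func in (f, g, h, explicit_loop):
    # Best of 3 runs of 1000 calls each.
    best = min(timeit.repeat(lambda: func(x), number=1000, repeat=3))
    print(f'{func.__name__:>14}: {best:.4f} s per 1000 calls')
```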

If we use dis.dis(), we can see how the functions actually behave.

count_even_digits_spyr03_for():

  7           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (count)

  8           6 SETUP_LOOP              42 (to 51)
              9 LOAD_GLOBAL              0 (str)
             12 LOAD_GLOBAL              1 (n)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                28 (to 50)
             22 STORE_FAST               1 (c)

  9          25 LOAD_FAST                1 (c)
             28 LOAD_CONST               2 ('02468')
             31 COMPARE_OP               6 (in)
             34 POP_JUMP_IF_FALSE       19

 10          37 LOAD_FAST                0 (count)
             40 LOAD_CONST               3 (1)
             43 INPLACE_ADD
             44 STORE_FAST               0 (count)
             47 JUMP_ABSOLUTE           19
        >>   50 POP_BLOCK

 11     >>   51 LOAD_FAST                0 (count)
             54 RETURN_VALUE

We can see that there's only one function call, to str(), at the beginning:

9 LOAD_GLOBAL              0 (str)
...
15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)

The rest of it is simple, inline bytecode: jumps, stores, and in-place addition.

As for count_even_digits_spyr03_sum():

 14           0 LOAD_GLOBAL              0 (sum)
              3 LOAD_CONST               1 (<code object <genexpr> at 0x10dcc8c90, file "test.py", line 14>)
              6 LOAD_CONST               2 ('count2.<locals>.<genexpr>')
              9 MAKE_FUNCTION            0
             12 LOAD_GLOBAL              1 (str)
             15 LOAD_GLOBAL              2 (n)
             18 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             21 GET_ITER
             22 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             25 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             28 RETURN_VALUE

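While I can't perfectly explain the differences, we can clearly see that there are more function calls: MAKE_FUNCTION builds a function object from the genexpr's code object, the first CALL_FUNCTION calls it with the string iterator, and the second one calls sum() on the resulting generator. These calls, plus resuming the generator for every character, make the code run much slower than the for loop, which executes its bytecode instructions inline.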
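To reproduce this on your own interpreter, dis can print the bytecode and also lets you inspect it programmatically. This is a small sketch; opcode names and offsets on newer Python versions will differ from the listing above.

```python
import dis
import types

def count_even_digits_spyr03_for(n):
    count = 0
    for c in str(n):
        if c in "02468":
            count += 1
    return count

def count_even_digits_spyr03_sum(n):
    return sum(c in "02468" for c in str(n))

# Print the bytecode of both functions (nested code objects are shown too).
dis.dis(count_even_digits_spyr03_for)
dis.dis(count_even_digits_spyr03_sum)

# The genexpr lives in its own code object, stored as a constant of the outer function.
nested = [const for const in count_even_digits_spyr03_sum.__code__.co_consts
          if isinstance(const, types.CodeType)]
print([code.co_name for code in nested])  # ['<genexpr>']

# The for-loop version keeps its loop inline, driven by FOR_ITER.
loop_ops = {ins.opname for ins in dis.get_instructions(count_even_digits_spyr03_for)}
print('FOR_ITER' in loop_ops)  # True
```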


@MarkusMeskanen's answer has the right bits – function calls are slow, and both genexprs and listcomps are basically function calls.

Anyway, to be pragmatic:

Using str.count(c) is faster, and this related answer of mine about strpbrk() in Python could make things faster still.

def count_even_digits_spyr03_count(n):
    s = str(n)
    return sum(s.count(c) for c in "02468")


def count_even_digits_spyr03_count_unrolled(n):
    s = str(n)
    return s.count("0") + s.count("2") + s.count("4") + s.count("6") + s.count("8")

Results:

string length: 502
count_even_digits_spyr03_list 0.04157966522
count_even_digits_spyr03_sum 0.05678154459
count_even_digits_spyr03_for 0.036128606150000006
count_even_digits_spyr03_count 0.010441866129999991
count_even_digits_spyr03_count_unrolled 0.009662931009999999

Several differences contribute to the observed performance gap. I aim to give a high-level overview of them without going too deep into the low-level details or possible improvements. For the benchmarks I use my own package simple_benchmark.

Generators vs. for loops

Generators and generator expressions are syntactic sugar that can be used instead of writing iterator classes.

When you write a generator like:

def count_even(num):
    s = str(num)
    for c in s:
        yield c in '02468'

Or a generator expression:

(c in '02468' for c in str(num))

That will be transformed (behind the scenes) into a state machine that is accessible through an iterator class. In the end it will be roughly equivalent to (although the actual code generated around a generator will be faster):

class Count:
    def __init__(self, num):
        self.str_num = iter(str(num))

    def __iter__(self):
        return self

    def __next__(self):
        c = next(self.str_num)
        return c in '02468'

So a generator always adds one layer of indirection: advancing the generator (or generator expression, or iterator) means calling __next__ on the iterator created by the generator, which in turn calls __next__ on the object you actually want to iterate over. There is also some overhead because you need to create one additional "iterator instance". Typically these overheads are negligible if you do anything substantial in each iteration.
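A small sketch stepping through the generator and the hand-written Count class side by side. Each next() call goes through the wrapper before reaching the string iterator, and both end up producing the same values.

```python
def count_even(num):
    s = str(num)
    for c in s:
        yield c in '02468'

class Count:
    def __init__(self, num):
        self.str_num = iter(str(num))

    def __iter__(self):
        return self

    def __next__(self):
        # Delegate to the underlying string iterator (the extra layer).
        # Its StopIteration propagates and ends the iteration.
        c = next(self.str_num)
        return c in '02468'

gen = count_even(1234)
it = Count(1234)
print(next(gen), next(it))  # False False ('1' is odd)
print(next(gen), next(it))  # True True ('2' is even)

# Fully consumed, both yield the same sequence.
print(list(count_even(1234)) == list(Count(1234)))  # True
```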

Just to provide an example of how much overhead a generator imposes compared to a manual loop:

import matplotlib.pyplot as plt
from simple_benchmark import BenchmarkBuilder
%matplotlib notebook

bench = BenchmarkBuilder()

@bench.add_function()
def iteration(it):
    for i in it:
        pass

@bench.add_function()
def generator(it):
    it = (item for item in it)
    for i in it:
        pass

@bench.add_arguments()
def argument_provider():
    for i in range(2, 15):
        size = 2**i
        yield size, [1 for _ in range(size)]

plt.figure()
result = bench.run()
result.plot()

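If you don't have simple_benchmark installed, a rougher stdlib-only version of the same comparison can be done with timeit. It prints numbers rather than a plot, and the sizes and repeat counts are arbitrary choices.

```python
import timeit

def iteration(it):
    for i in it:
        pass

def generator(it):
    # Wrap the input in a generator expression: one extra layer per element.
    it = (item for item in it)
    for i in it:
        pass

for exponent in (4, 8, 12):
    size = 2 ** exponent
    data = [1] * size
    t_iter = min(timeit.repeat(lambda: iteration(data), number=200, repeat=3))
    t_gen = min(timeit.repeat(lambda: generator(data), number=200, repeat=3))
    print(f'{size:>6} items: loop {t_iter:.5f} s, generator {t_gen:.5f} s')
```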

Generators vs. List comprehensions

Generators have the advantage that they don't create a list; they "produce" the values one by one. So while a generator has the overhead of the "iterator class", it can save the memory for creating an intermediate list. It's a trade-off between speed (list comprehension) and memory (generators). This has been discussed in various posts around Stack Overflow, so I don't want to go into much more detail here.

import matplotlib.pyplot as plt
from simple_benchmark import BenchmarkBuilder
%matplotlib notebook

bench = BenchmarkBuilder()

@bench.add_function()
def generator_expression(it):
    it = (item for item in it)
    for i in it:
        pass

@bench.add_function()
def list_comprehension(it):
    it = [item for item in it]
    for i in it:
        pass

@bench.add_arguments('size')
def argument_provider():
    for i in range(2, 15):
        size = 2**i
        yield size, list(range(size))

plt.figure()
result = bench.run()
result.plot()

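The memory side of this trade-off can be measured with the standard library's tracemalloc. Below is a minimal sketch reusing the two functions above; the data size is arbitrary, and peak traced allocations are only an approximation of real memory use. The input list is built before tracing starts, so only the extra allocations are counted.

```python
import tracemalloc

def peak_bytes(consume):
    """Return the peak traced allocation, in bytes, while running consume()."""
    tracemalloc.start()
    try:
        consume()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak

def generator_expression(it):
    it = (item for item in it)
    for i in it:
        pass

def list_comprehension(it):
    it = [item for item in it]
    for i in it:
        pass

data = list(range(100_000))

gen_peak = peak_bytes(lambda: generator_expression(data))
list_peak = peak_bytes(lambda: list_comprehension(data))
print(f'generator expression peak: {gen_peak} bytes')
print(f'list comprehension peak:   {list_peak} bytes')
```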

sum should be faster than manual iteration

Yes, sum is indeed faster than an explicit for loop, especially if you iterate over integers.

import matplotlib.pyplot as plt
from simple_benchmark import BenchmarkBuilder
%matplotlib notebook

bench = BenchmarkBuilder()

@bench.add_function()
def my_sum(it):
    sum_ = 0
    for i in it:
        sum_ += i
    return sum_

bench.add_function()(sum)

@bench.add_arguments()
def argument_provider():
    for i in range(2, 15):
        size = 2**i
        yield size, [1 for _ in range(size)]

plt.figure()
result = bench.run()
result.plot()

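Here is a stdlib-only variant of this comparison that also feeds sum a list of bools, which connects it to the int-only fast path discussed in the first answer. Whether bools are actually slower depends on how your CPython version implements that fast path, so treat the output as something to measure rather than a given.

```python
import timeit

def my_sum(it):
    sum_ = 0
    for i in it:
        sum_ += i
    return sum_

ints = [1] * 10_000
bools = [True] * 10_000

# Same total either way; only the per-item cost differs.
assert my_sum(ints) == sum(ints) == sum(bools) == 10_000

for label, func, data in [('my_sum, ints', my_sum, ints),
                          ('sum, ints', sum, ints),
                          ('sum, bools', sum, bools)]:
    best = min(timeit.repeat(lambda: func(data), number=200, repeat=3))
    print(f'{label:>12}: {best:.5f} s')
```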

String methods vs. any kind of Python loop

The key to understanding the performance difference between string methods like str.count and loops (explicit or implicit) is that strings in Python are actually stored as values in an (internal) array. That means a string method doesn't call any __next__ methods; it can loop directly over the array, which is significantly faster. However, it also imposes a method lookup and a method call on the string, which is why it's slower for very short numbers.

Here is a small comparison of how long it takes to iterate over a string with a Python loop vs. how long it takes Python to iterate over the internal array itself:

import matplotlib.pyplot as plt
from simple_benchmark import BenchmarkBuilder
%matplotlib notebook

bench = BenchmarkBuilder()

@bench.add_function()
def string_iteration(s):
    # there is no "a" in the string, so this iterates over the whole string
    return 'a' in s  

@bench.add_function()
def python_iteration(s):
    for c in s:
        pass

@bench.add_arguments('string length')
def argument_provider():
    for i in range(2, 20):
        size = 2**i
        yield size, '1'*size

plt.figure()
result = bench.run()
result.plot()

In this benchmark it's ~200 times faster to let Python do the iteration over the string than to iterate over the string with a for loop.

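Without simple_benchmark, the same comparison can be sketched with timeit. The string lengths and repeat counts are arbitrary, and the ratio you get will depend on your machine and Python version.

```python
import timeit

def string_iteration(s):
    # 'a' never occurs, so the C-level search scans the whole string.
    return 'a' in s

def python_iteration(s):
    for c in s:
        pass

for exponent in (4, 10, 14):
    s = '1' * 2 ** exponent
    t_c = min(timeit.repeat(lambda: string_iteration(s), number=100, repeat=3))
    t_py = min(timeit.repeat(lambda: python_iteration(s), number=100, repeat=3))
    print(f'{len(s):>6} chars: in-operator {t_c:.6f} s, '
          f'for loop {t_py:.6f} s, ratio {t_py / t_c:.0f}x')
```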

Why do all of them converge for large numbers?

This is because the number-to-string conversion dominates there. So for really huge numbers you're essentially just measuring how long it takes to convert that number to a string.

You'll see the difference if you compare the versions that take a number and convert it to a string with the ones that take the already-converted string (I use the functions from another answer here to illustrate that). On the left is the benchmark that takes numbers and on the right is the benchmark that takes strings; the y-axis is shared between both plots.

As you can see, the functions that take the string are significantly faster for large numbers than the ones that take a number and convert it to a string inside. This indicates that the string conversion is the "bottleneck" for large numbers. For convenience I also included a benchmark that only does the string conversion in the left plot (it becomes significant, and eventually dominant, for large numbers).
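A smaller-scale stdlib sketch of the same observation: time str() on its own next to counting from the number and counting from the pre-converted string. The 10,000-digit size is an arbitrary choice. It exceeds the default int/str conversion limit of recent Python versions, hence the set_int_max_str_digits guard.

```python
import random
import sys
import timeit

# Python 3.11+ (and some older security releases) limit int<->str conversion
# to 4300 digits by default; passing 0 disables the limit.
if hasattr(sys, 'set_int_max_str_digits'):
    sys.set_int_max_str_digits(0)

def count_from_number(x):
    s = str(x)
    return sum(s.count(c) for c in '02468')

def count_from_string(s):
    return sum(s.count(c) for c in '02468')

rng = random.Random(0)
# Non-zero leading digit, so str(int(digits)) round-trips exactly.
digits = rng.choice('123456789') + ''.join(rng.choice('0123456789') for _ in range(9_999))
number = int(digits)

t_str = min(timeit.repeat(lambda: str(number), number=20, repeat=3))
t_num = min(timeit.repeat(lambda: count_from_number(number), number=20, repeat=3))
t_s = min(timeit.repeat(lambda: count_from_string(digits), number=20, repeat=3))
print(f'str() alone: {t_str:.5f} s, count from number: {t_num:.5f} s, '
      f'count from string: {t_s:.5f} s')
```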

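%matplotlib notebook

from simple_benchmark import BenchmarkBuilder
import matplotlib.pyplot as plt
import random
import sys

# The largest inputs below have 16384 digits, which exceeds the default
# int/str conversion limit (4300 digits) of recent Python versions; 0 disables it.
if hasattr(sys, 'set_int_max_str_digits'):
    sys.set_int_max_str_digits(0)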

bench1 = BenchmarkBuilder()

@bench1.add_function()
def f1(x):
    return sum(c in '02468' for c in str(x))

@bench1.add_function()
def f2(x):
    return sum([c in '02468' for c in str(x)])

@bench1.add_function()
def f3(x):
    return sum([True for c in str(x) if c in '02468'])    

@bench1.add_function()
def f4(x):
    return sum([1 for c in str(x) if c in '02468'])

@bench1.add_function()
def explicit_loop(x):
    count = 0
    for c in str(x):
        if c in '02468':
            count += 1
    return count

@bench1.add_function()
def f5(x):
    s = str(x)
    return sum(s.count(c) for c in '02468')

bench1.add_function()(str)

@bench1.add_arguments(name='number length')
def arg_provider():
    for i in range(2, 15):
        size = 2 ** i
        yield (2**i, int(''.join(str(random.randint(0, 9)) for _ in range(size))))


bench2 = BenchmarkBuilder()

@bench2.add_function()
def f1(x):
    return sum(c in '02468' for c in x)

@bench2.add_function()
def f2(x):
    return sum([c in '02468' for c in x])

@bench2.add_function()
def f3(x):
    return sum([True for c in x if c in '02468'])    

@bench2.add_function()
def f4(x):
    return sum([1 for c in x if c in '02468'])

@bench2.add_function()
def explicit_loop(x):
    count = 0
    for c in x:
        if c in '02468':
            count += 1
    return count

@bench2.add_function()
def f5(x):
    return sum(x.count(c) for c in '02468')

@bench2.add_arguments(name='number length')
def arg_provider():
    for i in range(2, 15):
        size = 2 ** i
        yield (2**i, ''.join(str(random.randint(0, 9)) for _ in range(size)))

f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
b1 = bench1.run()
b2 = bench2.run()
b1.plot(ax=ax1)
b2.plot(ax=ax2)
ax1.set_title('Number')
ax2.set_title('String')