
How to know a generated sequence is at most a certain length

I want to know whether a generated sequence has fewer than 2 entries.

>>> def sequence():
...     for i in xrange(secret):
...         yield i

My inefficient method is to build a list and measure its length:

>>> secret = 5
>>> len(list(sequence())) < 2
False

Obviously, this consumes the whole generator.
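A minimal Python 3 illustration of the problem, with `range` in place of `xrange`. This version of `sequence` takes the length as a parameter instead of reading the global `secret`, a change made only for the example. Once `list()` has drained the generator, nothing is left for later use.

```python
def sequence(secret):
    # Stand-in for the question's generator, parameterised for this example
    for i in range(secret):
        yield i

gen = sequence(5)
print(len(list(gen)) < 2)  # False: a full list of 5 items was built to find out
print(list(gen))           # []: the generator is now exhausted
```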

In my real case the generator could be traversing a large network. I want to do the check without consuming the whole generator, or building a large list.

There's a recipe in the itertools documentation:

from itertools import islice

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

This builds a list of at most n items, which is better.

So I could say:

>>> len(take(2, sequence())) < 2
False
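As a sketch of why the bound matters, here is the recipe applied in Python 3 to `itertools.count()`, an infinite iterator used here as an assumed stand-in for a very large network traversal. `take` still returns after pulling two items, so only a two-element list is ever built.

```python
from itertools import count, islice

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

# count() never ends, but islice stops after n items,
# so at most 2 elements are ever produced or stored.
first_two = take(2, count())
print(first_two)           # [0, 1]
print(len(first_two) < 2)  # False: there are at least 2 entries
```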

Is there a more Pythonic and efficient way to do it?

Peter Wood asked Sep 24 '15 08:09


2 Answers

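As of Python 3.4 (PEP 424), iterators can implement a length hint. If an object implements one, it is exposed through the object.__length_hint__() method. Generator objects created by generator functions do not provide a hint, and a hint is only an estimate, so it may differ from the actual number of items.

You can query it with the operator.length_hint() function, which returns a default value (0 unless you pass another integer) when neither a length nor a hint is available.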
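A small sketch of what `operator.length_hint()` returns in three situations (Python 3.4+). `Countdown` is a hypothetical iterator class written only for this illustration; it implements `__length_hint__`, whereas a generator object does not, so the default comes back instead.

```python
from operator import length_hint

class Countdown:
    """Hypothetical iterator that reports how many items remain."""
    def __init__(self, start):
        self.remaining = start

    def __iter__(self):
        return self

    def __next__(self):
        if self.remaining <= 0:
            raise StopIteration
        self.remaining -= 1
        return self.remaining

    def __length_hint__(self):
        # Estimate of the number of items still to come
        return self.remaining

def numbers():
    yield from range(10)

print(length_hint([1, 2, 3], -1))     # 3: objects with __len__ report their length
print(length_hint(Countdown(4), -1))  # 4: taken from __length_hint__
print(length_hint(numbers(), -1))     # -1: generators give no hint, so the default is used
```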

If no hint is available, your only option is to consume elements, and the take() recipe you used is the most efficient way to do that:

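from itertools import chain, islice
from operator import length_hint

def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

gen = sequence()  # the generator you want to check
elements = []
length = length_hint(gen, -1)  # default must be an int; -1 means "no hint"
if length < 0:
    # No hint available: consume at most 2 items and count them
    elements = take(2, gen)
    length = len(elements)
if length >= 2:
    raise ValueError('expected fewer than 2 entries')
# Use elements, then gen: put the consumed items back in front
gen = chain(elements, gen)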
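The steps above can be packaged as a reusable function. This is a sketch under assumptions: the name `fewer_than` and its `(answer, iterator)` return value are hypothetical choices for illustration, and it returns a verdict instead of raising. Like the code above, it trusts a length hint when one exists.

```python
from itertools import chain, islice
from operator import length_hint

def fewer_than(n, iterable):
    """Hypothetical helper: report whether iterable has fewer than n items.

    Returns (answer, iterator); the iterator yields every original item,
    including any that had to be consumed to answer the question.
    """
    it = iter(iterable)
    elements = []
    length = length_hint(it, -1)  # -1 signals "no length or hint available"
    if length < 0:
        # No hint: pull at most n items and count them
        elements = list(islice(it, n))
        length = len(elements)
    # Put the consumed items back in front of the rest
    return length < n, chain(elements, it)

def sequence(secret):
    for i in range(secret):
        yield i

ok, rest = fewer_than(2, sequence(5))
print(ok)          # False
print(list(rest))  # [0, 1, 2, 3, 4]
```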

Martijn Pieters answered Nov 15 '22 18:11


The take() solution uses islice, builds a list and takes its length:

>>> from itertools import islice
>>> len(list(islice(sequence(), 2)))
2

To avoid creating the list, we can use sum:

>>> sum(1 for _ in islice(sequence(), 2))
2

This takes about 70% of the time of the list-based version:

>>> from timeit import timeit
>>> timeit('len(list(islice(xrange(1000), 2)))', 'from itertools import islice')
1.089650974650752

>>> timeit('sum(1 for _ in islice(xrange(1000), 2))', 'from itertools import islice')
0.7579448552500647

Wrapping it up:

>>> def at_most(n, elements):
...     return sum(1 for _ in islice(elements, n + 1)) <= n

>>> at_most(5, xrange(5))
True

>>> at_most(2, xrange(5))
False
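A Python 3 version of the same helper (`range` replacing `xrange`), sketched to show two properties that follow from `islice`: it terminates even on an unbounded iterator such as `itertools.count()`, and it advances an iterator by at most n + 1 items. Those consumed items are gone unless you keep them, as the `chain()` approach in the other answer does.

```python
from itertools import count, islice

def at_most(n, elements):
    # Count at most n + 1 items; seeing n + 1 means there are more than n
    return sum(1 for _ in islice(elements, n + 1)) <= n

print(at_most(5, range(5)))  # True
print(at_most(2, range(5)))  # False
print(at_most(3, count()))   # False, and it returns although count() is infinite

it = iter(range(10))
at_most(3, it)
print(next(it))  # 4: items 0..3 (n + 1 of them) were consumed by the check
```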

Peter Wood answered Nov 15 '22 18:11