 

How not to miss the next element after itertools.takewhile()

Say we want to process an iterator in chunks.
The logic for each chunk depends on previously calculated chunks, so groupby() does not help.

Our friend in this case is itertools.takewhile():

import itertools

while True:
    # getNewChunkLogic() and process() stand in for the real per-chunk logic
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)

The problem is that takewhile() has to consume one element past the last one that satisfies the new chunk logic, thus 'eating' the first element of the next chunk.
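A minimal demonstration of the problem, using a plain `range` iterator and an illustrative `x < 5` predicate in place of the real chunk logic:

```python
from itertools import takewhile

it = iter(range(10))
first = list(takewhile(lambda x: x < 5, it))  # stops when it sees 5
rest = list(it)                               # 5 is no longer in the iterator

print(first)  # [0, 1, 2, 3, 4]
print(rest)   # [6, 7, 8, 9]
```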

There are various workarounds, such as wrapping the iterator or pushing the element back à la C's ungetc().
My question is: is there an elegant solution?

asked Jun 03 '15 09:06 by Paul Oyster


2 Answers

takewhile() does indeed have to consume the next element to decide when to stop.
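To see why, here is a sketch that mirrors the pure-Python equivalent of takewhile() given in the itertools documentation (renamed `takewhile_sketch` here so it does not shadow the real one). By the time the loop breaks, the failing element has already been pulled from the iterator:

```python
def takewhile_sketch(predicate, iterable):
    # Mirrors the documented pure-Python equivalent of itertools.takewhile()
    for x in iterable:
        if not predicate(x):
            break  # x was already pulled from the iterator and is now dropped
        yield x

it = iter([1, 4, 6, 3, 8])
print(list(takewhile_sketch(lambda x: x < 5, it)))  # [1, 4]
print(list(it))                                     # [3, 8]: the 6 is gone
```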

You could use a wrapper that tracks the last seen element, and that can be 'reset' to back up one element:

_sentinel = object()

class OneStepBuffered(object):
    def __init__(self, it):
        self._it = iter(it)
        self._last = _sentinel
        self._next = _sentinel
    def __iter__(self):
        return self
    def __next__(self):
        if self._next is not _sentinel:
            next_val, self._next = self._next, _sentinel
            self._last = next_val  # re-arm step_back() for a replayed value
            return next_val
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise
    next = __next__  # Python 2 compatibility
    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel

Wrap your iterator in this one before using it with takewhile():

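import itertools

myIterator = OneStepBuffered(myIterator)
while True:
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)  # must consume the chunk fully before stepping back
    try:
        myIterator.step_back()
    except ValueError:
        break  # the underlying iterator is exhausted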

Demo:

>>> from itertools import takewhile
>>> test_list = range(10)
>>> iterator = OneStepBuffered(test_list)
>>> list(takewhile(lambda i: i < 5, iterator))
[0, 1, 2, 3, 4]
>>> iterator.step_back()
>>> list(iterator)
[5, 6, 7, 8, 9]
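Putting the pieces together, here is a self-contained sketch of the whole chunking loop. It repeats a compact copy of `OneStepBuffered` (recording a replayed value as the last one seen, so `step_back()` keeps working) and stops when `step_back()` raises because the underlying iterator is exhausted. `chunked()` and `make_predicate()` are hypothetical names; `make_predicate()` stands in for `getNewChunkLogic()` with an illustrative rule (values below 2 × (last value of the previous chunk + 1)). The sketch assumes every new predicate accepts at least the next element; otherwise the loop would spin on an empty chunk.

```python
from itertools import takewhile

_sentinel = object()

class OneStepBuffered:
    """Compact copy of the wrapper above; replayed values are re-recorded."""
    def __init__(self, it):
        self._it = iter(it)
        self._last = self._next = _sentinel
    def __iter__(self):
        return self
    def __next__(self):
        if self._next is not _sentinel:
            # Replay the pushed-back value and remember it as the last one seen
            self._last, self._next = self._next, _sentinel
            return self._last
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise
    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel

def make_predicate(previous_chunk):
    # Hypothetical stand-in for getNewChunkLogic(): the limit for the next
    # chunk is derived from the last value of the previous chunk
    limit = 2 * (previous_chunk[-1] + 1) if previous_chunk else 2
    return lambda x: x < limit

def chunked(iterable):
    it = OneStepBuffered(iterable)
    chunk = []
    while True:
        # list() consumes the chunk fully before we step back
        chunk = list(takewhile(make_predicate(chunk), it))
        if chunk:
            yield chunk
        try:
            it.step_back()  # un-eat the element that ended this chunk
        except ValueError:
            return  # the underlying iterator is exhausted

print(list(chunked(range(10))))  # [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```

Every element of `range(10)` ends up in exactly one chunk; nothing is lost at the boundaries.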
answered Oct 18 '22 13:10 by Martijn Pieters


I had the same problem. You might wish to use itertools.tee or itertools.pairwise (new in Python 3.10) to deal with this, but I didn't think those solutions were very elegant.
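For comparison, a tee()-based version might look like the sketch below (`split_with_tee` is an illustrative name). It scans one copy of the iterator with takewhile() and skips the same number of elements in the other copy, so the element that ended the scan is still available. tee() buffers everything one copy has seen and the other has not, and the original iterator must not be used after the split.

```python
from itertools import islice, takewhile, tee

def split_with_tee(predicate, iterable):
    """Return (head, rest): head as takewhile() would produce it, and an
    iterator over everything after head, including the element that
    ended head."""
    scan, rest = tee(iterable)
    head = list(takewhile(predicate, scan))
    # `rest` is an independent copy, so skip exactly the elements in head
    return head, islice(rest, len(head), None)

head, rest = split_with_tee(lambda x: x < 5, iter(range(10)))
print(head)        # [0, 1, 2, 3, 4]
print(list(rest))  # [5, 6, 7, 8, 9]
```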

The best approach I found is to rewrite takewhile() so that it also yields the element that fails the predicate, based heavily on the pure-Python equivalent in the documentation:

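def takewhile_inclusive(predicate, it):
    # Like itertools.takewhile(), but also yields the first element that
    # fails the predicate instead of silently consuming it
    for x in it:
        if predicate(x):
            yield x
        else:
            yield x
            break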

In your loop you can elegantly treat the final element separately using unpacking:

*chunk, lastPiece = takewhile_inclusive(getNewChunkLogic(), myIterator)

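You can then chain the last piece back onto the front of the iterator so that it starts the next chunk. Note that the unpacking raises ValueError if the iterator is already empty: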

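import itertools

while True:
    predicate = getNewChunkLogic()
    *chunk, lastPiece = takewhile_inclusive(predicate, myIterator)
    if predicate(lastPiece):  # data ran out mid-chunk (assumes a pure predicate)
        process(chunk + [lastPiece])
        break
    process(chunk)
    myIterator = itertools.chain([lastPiece], myIterator)  # push it back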
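A self-contained sketch of this approach as a generator follows. `chunked_inclusive()` and `make_predicate()` are hypothetical names, the predicate rule is illustrative, and the re-check `predicate(last)` assumes the predicate is side-effect free. The re-check matters: when the data runs out mid-chunk, the last unpacked element passed the predicate and belongs to the current chunk, and pushing it back would loop forever.

```python
import itertools

def takewhile_inclusive(predicate, it):
    # Equivalent to the function above: also yields the first failing element
    for x in it:
        yield x
        if not predicate(x):
            break

def make_predicate(previous_chunk):
    # Hypothetical stand-in for getNewChunkLogic()
    limit = 2 * (previous_chunk[-1] + 1) if previous_chunk else 2
    return lambda x: x < limit

def chunked_inclusive(iterable):
    it = iter(iterable)
    chunk = []
    while True:
        predicate = make_predicate(chunk)
        try:
            *body, last = takewhile_inclusive(predicate, it)
        except ValueError:
            return  # nothing left to unpack: the data is exhausted
        if predicate(last):
            # The data ran out mid-chunk, so `last` belongs to this chunk
            yield body + [last]
            return
        chunk = body
        yield chunk
        # `last` failed the predicate: push it back to start the next chunk
        it = itertools.chain([last], it)

print(list(chunked_inclusive(range(10))))  # [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
```

Each push-back wraps the iterator in one more chain() layer, so for very many chunks the wrapper-class approach may be preferable.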
  
answered Oct 18 '22 13:10 by Duncan