Wrapping asyncio.gather in a timeout

I've seen asyncio.gather vs asyncio.wait, but am not sure if that addresses this particular question. What I'm looking to do is wrap the asyncio.gather() coroutine in asyncio.wait_for(), with a timeout argument. I also need to satisfy these conditions:

  • return_exceptions=True (from asyncio.gather()): rather than propagating exceptions to the task that awaits on gather(), I want to include exception instances in the results.
  • Order: retain the property of asyncio.gather() that the order of results is the same as the order of the input (or otherwise, map the output back to the input). asyncio.wait() fails this criterion, and I'm not sure of the ideal way to achieve it.

The timeout is for the entire asyncio.gather() across the list of awaitables. If an awaitable gets caught by the timeout or raises an exception, either case should just place an exception instance in the result list.
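For reference, here is a minimal sketch of why the naive approach of wrapping gather() in wait_for() doesn't meet these conditions. The coroutine work() and the timings are hypothetical stand-ins. When the timeout fires, wait_for() cancels the whole gather() and raises TimeoutError, so even the awaitables that had already finished contribute nothing: there is no partial, ordered result list to inspect.

```python
import asyncio


async def work(delay: float) -> float:
    # Hypothetical stand-in for coro(): sleep, then return the delay.
    await asyncio.sleep(delay)
    return delay


async def naive(delays, timeout):
    try:
        return await asyncio.wait_for(
            asyncio.gather(*(work(d) for d in delays), return_exceptions=True),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # gather() was cancelled as a whole; the 0.01 s result is lost too.
        return None


result = asyncio.run(naive([0.01, 1.0], timeout=0.1))
print(result)  # None
```

return_exceptions=True does not help here, because cancellation of the gather() itself is always propagated.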

Consider this setup:

>>> import asyncio
>>> import random
>>> from time import perf_counter
>>> from typing import Iterable
>>> from pprint import pprint
>>> 
>>> async def coro(i, threshold=0.4):
...     await asyncio.sleep(i)
...     if i > threshold:
...         # For illustration's sake - some coroutines may raise,
...         # and we want to accommodate that and just test for exception
...         # instances in the results of asyncio.gather(return_exceptions=True)
...         raise Exception("i too high")
...     return i
... 
>>> async def main(n, it: Iterable):
...     res = await asyncio.gather(
...         *(coro(i) for i in it),
...         return_exceptions=True
...     )
...     return res
... 
>>> 
>>> random.seed(444)
>>> n = 10
>>> it = [random.random() for _ in range(n)]
>>> start = perf_counter()
>>> res = asyncio.run(main(n, it=it))
>>> elapsed = perf_counter() - start
>>> print(f"Done main({n}) in {elapsed:0.2f} seconds")  # Expectation: just under 1 second
Done main(10) in 0.86 seconds
>>> pprint(dict(zip(it, res)))
{0.01323751590501987: 0.01323751590501987,
 0.07422124156714727: 0.07422124156714727,
 0.3088946587429545: 0.3088946587429545,
 0.3113884366691503: 0.3113884366691503,
 0.4419557492849159: Exception('i too high'),
 0.4844375347808497: Exception('i too high'),
 0.5796792804615848: Exception('i too high'),
 0.6338658027451068: Exception('i too high'),
 0.7426396870165088: Exception('i too high'),
 0.8614799253779063: Exception('i too high')}

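The program above, with n = 10, takes roughly as long as its slowest coroutine when run asynchronously, because all the sleeps overlap. That is the largest of the ten values from random.random(), which is uniformly distributed in [0, 1); the expected maximum of n such draws is n/(n + 1), about 0.91 seconds for n = 10. For this seed it is about 0.86 seconds, plus a bit of overhead.

Let's say I want to impose a half-second timeout on the entire operation (i.e. on the coroutine main()):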

timeout = 0.5

Now, I can use asyncio.wait(), but the problem is that its results are set objects, so they can't guarantee the ordered-results property of asyncio.gather():

>>> async def main(n, it, timeout) -> tuple:
...     tasks = [asyncio.create_task(coro(i)) for i in it]
...     done, pending = await asyncio.wait(tasks, timeout=timeout)
...     return done, pending
... 
>>> timeout = 0.5
>>> random.seed(444)
>>> it = [random.random() for _ in range(n)]
>>> start = perf_counter()
>>> done, pending = asyncio.run(main(n, it=it, timeout=timeout))
>>> for i in pending:
...     i.cancel()
>>> elapsed = perf_counter() - start
>>> print(f"Done main({n}) in {elapsed:0.2f} seconds")
Done main(10) in 0.50 seconds
>>> done
{<Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>, <Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>, <Task finished coro=<coro() done, defined at <stdin>:1> result=0.3088946587429545>, <Task finished coro=<coro() done, defined at <stdin>:1> result=0.3113884366691503>, <Task finished coro=<coro() done, defined at <stdin>:1> result=0.01323751590501987>, <Task finished coro=<coro() done, defined at <stdin>:1> result=0.07422124156714727>}
>>> pprint(done)
{<Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.3113884366691503>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.07422124156714727>,
 <Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.01323751590501987>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.3088946587429545>}
>>> pprint(pending)
{<Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>}

As stated above, the issue is that I seemingly can't map the task instances back to the inputs in it. The task identities are effectively lost inside the function scope with tasks = [asyncio.create_task(coro(i)) for i in it]. Is there a Pythonic way, or a use of the asyncio API, to mimic the behavior of asyncio.gather() here?

Brad Solomon asked Jan 29 '19 18:01

1 Answer

Taking a look at the underlying _wait() coroutine: asyncio.wait() doesn't copy or wrap the tasks it is given, so the done and pending sets it returns contain the very same Task objects that were passed in. This means that, within the scope of main(), the list built by tasks = [asyncio.create_task(coro(i)) for i in it] holds the same objects that await asyncio.wait(tasks, timeout=timeout) reports on, and their state has been updated in place by the time wait() returns. Instead of returning a (done, pending) tuple, one workaround is to just return tasks itself, which retains the order of the input it. wait()/_wait() only separates the tasks into done/pending subsets; in this case we can discard those subsets and use the whole list of tasks.
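A quick way to check that asyncio.wait() works on the same Task objects you pass in is to compare its two sets against the original list. This is a minimal sketch; work() and the delays are hypothetical. Tasks hash by identity, so set equality here means "same objects", and the list itself still carries the input order.

```python
import asyncio


async def work(delay: float) -> float:
    # Hypothetical stand-in for coro() from the question.
    await asyncio.sleep(delay)
    return delay


async def main():
    tasks = [asyncio.create_task(work(d)) for d in (0.3, 0.01, 0.02)]
    done, pending = await asyncio.wait(tasks, timeout=0.1)
    # wait() hands back the same Task objects, just split into two sets.
    same_objects = (done | pending) == set(tasks)
    # The list keeps the input order; only the sets lose it.
    states = [t.done() for t in tasks]
    for t in pending:
        t.cancel()  # don't leave the 0.3 s task running
    return same_objects, states


result = asyncio.run(main())
print(result)  # (True, [False, True, True])
```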

There are three possible task states in this case:

  • The task's coroutine (coro()) returned normally, without raising, and finished within the timeout. Its .cancelled() is False, and .result() returns that value.
  • The task was still running when the timeout hit. asyncio.wait() does not cancel it, but asyncio.run() cancels any tasks still outstanding once main() returns, so by the time we inspect it, .cancelled() is True and both .result() and .exception() raise CancelledError.
  • The task had time to finish, but coro() raised an exception. Its .cancelled() is False, .exception() returns the exception instance, and .result() re-raises it.

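The three cases above can be turned into a small helper that converts a finished task into "its result or its exception". This is a sketch under stated assumptions: outcome(), ok(), boom() and slow() are hypothetical names, and the task must already be done (otherwise .exception() raises InvalidStateError). Unlike the main() below, it cancels and awaits the straggler itself rather than relying on asyncio.run().

```python
import asyncio


def outcome(task: asyncio.Task):
    """Return a finished task's result, or the exception that ended it."""
    if task.cancelled():
        # Case 2: cut off by the timeout and then cancelled.
        return asyncio.CancelledError()
    exc = task.exception()  # returns (does not raise) the coroutine's exception
    if exc is not None:
        return exc  # Case 3: the coroutine raised
    return task.result()  # Case 1: normal return value


async def ok():
    return 42


async def boom():
    raise ValueError("bad")


async def slow():
    await asyncio.sleep(10)


async def main():
    tasks = [asyncio.create_task(c()) for c in (ok, boom, slow)]
    await asyncio.wait(tasks, timeout=0.05)
    tasks[2].cancel()
    # Let the cancellation take effect before inspecting the task.
    await asyncio.gather(tasks[2], return_exceptions=True)
    return [outcome(t) for t in tasks]


results = asyncio.run(main())
print(results)  # [42, ValueError('bad'), CancelledError()]
```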
(All of this is laid out in asyncio/futures.py.)


Illustration:

>>> # imports/other code snippets - see question
>>> async def main(n, it, timeout) -> list:
...     tasks = [asyncio.create_task(coro(i)) for i in it]
...     await asyncio.wait(tasks, timeout=timeout)
...     return tasks  # *not* (done, pending)
... 

>>> timeout = 0.5
>>> random.seed(444)
>>> n = 10
>>> it = [random.random() for _ in range(n)]
>>> start = perf_counter()
>>> tasks = asyncio.run(main(n, it=it, timeout=timeout))
>>> elapsed = perf_counter() - start
>>> print(f"Done main({n}) in {elapsed:0.2f} seconds")
Done main(10) in 0.50 seconds

>>> pprint(tasks)
[<Task finished coro=<coro() done, defined at <stdin>:1> result=0.3088946587429545>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.01323751590501987>,
 <Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>,
 <Task finished coro=<coro() done, defined at <stdin>:1> exception=Exception('i too high')>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.3113884366691503>,
 <Task finished coro=<coro() done, defined at <stdin>:1> result=0.07422124156714727>,
 <Task cancelled coro=<coro() done, defined at <stdin>:1>>]

Now apply the logic from above, which gives a list res whose order matches the inputs:

>>> res = []
>>> for t in tasks:
...     try:
...         r = t.result()
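...     except (Exception, asyncio.CancelledError) as e:
...         # CancelledError derives from BaseException, not Exception, on Python 3.8+
...         res.append(e)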
...     else:
...         res.append(r)
>>> pprint(res)
[0.3088946587429545,
 0.01323751590501987,
 Exception('i too high'),
 CancelledError(),
 CancelledError(),
 CancelledError(),
 Exception('i too high'),
 0.3113884366691503,
 0.07422124156714727,
 CancelledError()]
>>> dict(zip(it, res))
{0.3088946587429545: 0.3088946587429545,
 0.01323751590501987: 0.01323751590501987,
 0.4844375347808497: Exception('i too high'),
 0.8614799253779063: concurrent.futures._base.CancelledError(),
 0.7426396870165088: concurrent.futures._base.CancelledError(),
 0.6338658027451068: concurrent.futures._base.CancelledError(),
 0.4419557492849159: Exception('i too high'),
 0.3113884366691503: 0.3113884366691503,
 0.07422124156714727: 0.07422124156714727,
 0.5796792804615848: concurrent.futures._base.CancelledError()}
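Pulling the pieces together, here is a minimal sketch of a reusable helper. The name gather_with_timeout is hypothetical (it is not part of asyncio). It assumes, like the answer, that a timed-out awaitable should be represented by a CancelledError instance. Unlike main() above, it cancels and awaits the leftover tasks itself instead of relying on asyncio.run() to cancel them at shutdown, so it should also behave when called from inside a longer-running program.

```python
import asyncio
from typing import Any, Awaitable, Iterable, List


async def gather_with_timeout(aws: Iterable[Awaitable[Any]], timeout: float) -> List[Any]:
    """Like gather(return_exceptions=True), but with one overall timeout.

    Results are in input order; awaitables that raised contribute their
    exception, and those still running at the deadline contribute a CancelledError.
    """
    tasks = [asyncio.ensure_future(aw) for aw in aws]
    if not tasks:
        return []  # asyncio.wait() rejects an empty collection
    _, pending = await asyncio.wait(tasks, timeout=timeout)
    # wait() does not cancel anything itself, so cancel the stragglers...
    for t in pending:
        t.cancel()
    # ...and give the cancellations a chance to complete.
    if pending:
        await asyncio.wait(pending)
    results = []
    for t in tasks:
        if t.cancelled():
            results.append(asyncio.CancelledError())  # hit by the timeout
        elif t.exception() is not None:
            results.append(t.exception())  # the coroutine raised
        else:
            results.append(t.result())  # normal return value
    return results


async def coro(i, threshold=0.4):
    # Same coroutine as in the question.
    await asyncio.sleep(i)
    if i > threshold:
        raise Exception("i too high")
    return i


async def main():
    return await gather_with_timeout([coro(i) for i in (0.01, 0.42, 2.0)], timeout=0.6)


res = asyncio.run(main())
print(res)  # [0.01, Exception('i too high'), CancelledError()]
```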
Brad Solomon answered Nov 13 '22 09:11