Python - Join Multiple Threads With Timeout

I have multiple Process threads running and I'd like to join all of them with a single overall timeout. I understand that if no timeout were necessary, I could write:

for thread in threads:
    thread.join()

One solution I thought of was to start a master process that joins all the threads, and then join that master process with a timeout. However, Python raised the following error:

AssertionError: can only join a child process

The code I have is below.

import multiprocessing

def join_all(threads):
    for thread in threads:
        thread.join()

if __name__ == '__main__':
    # threads: a list of multiprocessing.Process objects created earlier
    for thread in threads:
        thread.start()

    master = multiprocessing.Process(target=join_all, args=(threads,))
    master.start()
    master.join(timeout=60)
asked Dec 15 '22 at 22:12 by tonyduan
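For context on the error: multiprocessing only lets a process join children it started itself (Process.join asserts this), and the master process is a sibling of the worker processes, not their parent, so its join() calls fail. The joining therefore has to happen in the original parent. One common way to do that with an overall timeout, not taken from the answers below, is to share a single deadline and give each join() only the time that remains. In this sketch, join_with_deadline is a hypothetical helper name; it should work with threading.Thread and multiprocessing.Process alike, since both provide join(timeout) and is_alive(). The demo uses threads with made-up sleep durations to keep it short.

```python
import threading
import time


def join_with_deadline(workers, timeout):
    """Join every worker, spending at most `timeout` seconds in total.

    Hypothetical helper: works with any object that has join(timeout)
    and is_alive(), e.g. threading.Thread or multiprocessing.Process.
    Must be called from the thread/process that started the workers.
    Returns the workers that are still running after the deadline.
    """
    deadline = time.monotonic() + timeout
    for worker in workers:
        # Give each join only the time left before the shared deadline.
        remaining = max(0.0, deadline - time.monotonic())
        worker.join(remaining)
    return [w for w in workers if w.is_alive()]


if __name__ == '__main__':
    # Hypothetical workers that sleep for different lengths of time.
    workers = [threading.Thread(target=time.sleep, args=(d,), daemon=True)
               for d in (0.1, 0.2, 5)]
    for w in workers:
        w.start()
    still_running = join_with_deadline(workers, timeout=1)
    print(len(still_running), 'worker(s) still running after the timeout')
```

Because each join() receives only the remaining time, the total wait is bounded by the timeout no matter how many workers there are, and no sleep-based polling is needed.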



2 Answers

You could loop over each thread repeatedly, doing non-blocking checks to see if the thread is done:

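import time

def timed_join_all(threads, timeout):
    # time.monotonic() is unaffected by system clock changes, unlike time.time()
    start = cur_time = time.monotonic()
    while cur_time <= (start + timeout):
        for thread in threads:
            if not thread.is_alive():
                # Already finished, so this join returns immediately
                thread.join()
        if not any(thread.is_alive() for thread in threads):
            break  # every thread is done; don't wait out the rest of the timeout
        time.sleep(1)
        cur_time = time.monotonic()

if __name__ == '__main__':
    # threads: a list of Thread (or multiprocessing.Process) objects created earlier
    for thread in threads:
        thread.start()

    timed_join_all(threads, 60)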
answered Dec 17 '22 at 13:12 by dano
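timed_join_all returns silently when the timeout expires, so any workers still running are simply left running. Since the question uses multiprocessing.Process, one option not covered by either answer is to terminate the stragglers (threading.Thread has no public way to be killed, so this only applies to processes). A minimal sketch, where slow_task and terminate_stragglers are hypothetical names and a fixed sleep stands in for the timed join:

```python
import multiprocessing
import time


def slow_task(seconds):
    # Hypothetical worker used only for this demo.
    time.sleep(seconds)


def terminate_stragglers(processes):
    """Terminate and reap any processes still alive; return how many were stopped."""
    stopped = 0
    for proc in processes:
        if proc.is_alive():
            proc.terminate()  # SIGTERM on POSIX, TerminateProcess() on Windows
            proc.join()       # reap it so it does not linger as a zombie
            stopped += 1
    return stopped


if __name__ == '__main__':
    procs = [multiprocessing.Process(target=slow_task, args=(d,)) for d in (0.1, 30)]
    for p in procs:
        p.start()
    time.sleep(1)  # in real code, call timed_join_all(procs, 1) here instead
    print(terminate_stragglers(procs), 'process(es) terminated')
```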



This answer started from dano's, but makes a number of changes.

join_all takes a list of threads and a timeout (in seconds) and attempts to join all of them. It does this by making non-blocking calls to Thread.join, passing timeout=0 (join with no arguments blocks until the thread finishes and never times out).
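A small demo of why the timeout=0 call is non-blocking, and why is_alive() is still needed: join() always returns None, whether or not the thread has finished, so its return value tells you nothing. The worker here is a hypothetical one that waits on a threading.Event; multiprocessing.Process.join(timeout=0) likewise returns None without blocking.

```python
import threading

# Hypothetical worker: blocks until the event is set.
release = threading.Event()
worker = threading.Thread(target=release.wait)
worker.start()

result = worker.join(timeout=0)   # returns immediately, even though the thread is blocked
print(result, worker.is_alive())  # prints: None True

release.set()                     # let the worker finish
worker.join()                     # no timeout: blocks until the thread exits
print(worker.is_alive())          # prints: False
```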

Once all the threads have finished (checked by calling is_alive() on each of them), the loop exits early instead of waiting out the full timeout.

If some threads are still running by the time the timeout occurs, the function raises a RuntimeError with information about the remaining threads.

import time

def join_all(threads, timeout):
    """
    Args:
        threads: a list of thread objects to join
        timeout: the maximum time, in seconds, to wait for the threads to finish
    Raises:
        RuntimeError: if not all the threads have finished by the timeout
    """
    start = cur_time = time.time()
    while cur_time <= (start + timeout):
        for thread in threads:
            if thread.is_alive():
                thread.join(timeout=0)
        if all(not t.is_alive() for t in threads):
            break
        time.sleep(0.1)
        cur_time = time.time()
    else:
        still_running = [t for t in threads if t.is_alive()]
        num = len(still_running)
        names = [t.name for t in still_running]
        raise RuntimeError('Timeout on {0} threads: {1}'.format(num, names))

if __name__ == '__main__':
    # threads: a list of Thread (or multiprocessing.Process) objects created earlier
    for thread in threads:
        thread.start()

    join_all(threads, 60)

I used this inside a test suite where the threads were dæmonised versions of ExcThread, so it didn't matter if some of them never finished: daemon threads don't keep the interpreter alive at exit.

answered Dec 17 '22 at 13:12 by Milliams
