
Parallel recursive function in Python

How do I parallelize a recursive function in Python?

My function looks like this:

def f(x, depth):
    if x == 0:
        return ...
    else:
        return [x] + map(lambda x: f(x, depth-1), list_of_values(x))

def list_of_values(x):
    # Heavy compute, pure function

When I try to parallelize it with multiprocessing.Pool.map, Windows opens an infinite number of processes and hangs.

What's a good (preferably simple) way to parallelize it (for a single multicore machine)?

Here is the code that hangs:

from multiprocessing import Pool
pool = Pool(processes=4)

def f(x, depth):
    if x == 0:
        return ...
    else:
        return [x] + pool.map(lambda x: f(x, depth-1), list_of_values(x))

def list_of_values(x):
    # Heavy compute, pure function
Alex Brooks, asked Aug 28 '11


2 Answers

OK, sorry for the problems with this.

I'm going to answer a slightly different question where f() returns the sum of the values in the list. That is because it's not clear to me from your example what the return type of f() would be, and using an integer makes the code simple to understand.

This is complex because there are two different things happening in parallel:

  1. the calculation of the expensive function in the pool
  2. the recursive expansion of f()

I am careful to use the pool only to calculate the expensive function. That way we don't get an "explosion" of processes. But because this is asynchronous, we need to postpone a lot of work to the callback that the worker calls once the expensive function is done.

More than that, we need to use a countdown latch so that we know when all the separate sub-calls to f() are complete.

There may be a simpler way (I am pretty sure there is, but I need to do other things), but perhaps this gives you an idea of what is possible:

from multiprocessing import Pool, Value, RawArray, RLock
from time import sleep

class Latch:

    '''A countdown latch that lets us wait for a job of "n" parts'''

    def __init__(self, n):
        self.__counter = Value('i', n)
        self.__lock = RLock()

    def decrement(self):
        with self.__lock:
            self.__counter.value -= 1
            print('dec', self.read())
        return self.read() == 0

    def read(self):
        with self.__lock:
            return self.__counter.value

    def join(self):
        while self.read():
            sleep(1)


def list_of_values(x):
    '''An expensive function'''
    print(x, ': thinking...')
    sleep(1)
    print(x, ': thought')
    return list(range(x))


pool = Pool()


def async_f(x, on_complete=None):
    '''Return the sum of the values in the expensive list'''
    if x == 0:
        on_complete(0) # no list, return 0
    else:
        n = x # need to know size of result beforehand
        latch = Latch(n) # wait for n entries to be calculated
        result = RawArray('i', n+1) # where we will assemble the map
        def delayed_map(values):
            '''This is the callback for the pool async process - it runs
               in a separate thread within this process once the
               expensive list has been calculated and orchestrates the
               mapping of f over the result.'''
            result[0] = x # first value in list is x
            for (i, v) in enumerate(values):
                def callback(fx, i=i):
                    '''This is the callback passed to f() and is called when
                       the function completes.  If it is the last of all the
                       calls in the map then it calls on_complete() (ie another
                       instance of this function) for the calling f().'''
                    result[i+1] = fx
                    if latch.decrement(): # have completed list
                        # at this point result contains [x]+map(f, ...)
                        on_complete(sum(result)) # so return sum
                async_f(v, callback)
        # Ask worker to generate list then call delayed_map
        pool.apply_async(list_of_values, [x], callback=delayed_map)


def run():
    '''Tie into the same mechanism as above, for the final value.'''
    result = Value('i')
    latch = Latch(1)
    def final_callback(value):
        result.value = value
        latch.decrement()
    async_f(6, final_callback)
    latch.join() # wait for everything to complete
    return result.value


print(run())

PS: I am using Python 3.2 and the ugliness above is because we are delaying computation of the final results (going back up the tree) until later. It's possible something like generators or futures could simplify things.
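
For example, here is a minimal sketch of the futures idea (not the code above rewritten, just an illustration with the same toy list_of_values()): keep only list_of_values() in the pool and expand the tree level by level in the main process, so that all the expensive calls for a given level run concurrently.

from concurrent.futures import ProcessPoolExecutor

def list_of_values(x):
    '''Stand-in for the expensive, pure function.'''
    return list(range(x))

def f(x, executor):
    '''x plus the sum of f() over list_of_values(x), expanded breadth-first
       so that every expensive call for one level runs in the pool at once.'''
    total = 0
    frontier = [x]
    while frontier:
        total += sum(frontier)
        # all the list_of_values() calls for this level run concurrently
        expansions = executor.map(list_of_values, [v for v in frontier if v != 0])
        frontier = [child for values in expansions for child in values]
    return total

if __name__ == '__main__':
    with ProcessPoolExecutor() as executor:
        print(f(6, executor))  # prints 63, the same value run() above computes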

Also, I suspect you need a cache to avoid needlessly recalculating the expensive function when called with the same argument as earlier.
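
One simple way to get that, assuming list_of_values() really is pure as the question says, is functools.lru_cache (also added in Python 3.2). The cache lives per process, so each pool worker keeps its own copy; this is only a sketch of the idea, not something the code above relies on:

from functools import lru_cache

@lru_cache(maxsize=None)
def list_of_values(x):
    '''Expensive, pure function; repeated arguments are computed only once
       within each worker process.'''
    return list(range(x))

(Returning a tuple rather than a list would avoid accidentally mutating the cached value.)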

See also yaniv's answer - which seems to be an alternative way to reverse the order of the evaluation by being explicit about depth.

andrew cooke, answered Oct 10 '22


After thinking about this, I found a simple, not complete, but good enough answer:

# A partially parallel solution: just do the first level of recursion in
# parallel. It might be enough work to fill all cores.
import multiprocessing

def f_helper(data):
    return f(x=data['x'], depth=data['depth'], recursion_depth=data['recursion_depth'])

def f(x, depth, recursion_depth):
    if depth == 0:
        return ...
    else:
        if recursion_depth == 0:
            # only the top level of the recursion uses the pool
            pool = multiprocessing.Pool(processes=4)
            result = [x] + pool.map(f_helper,
                                    [{'x': _x, 'depth': depth - 1, 'recursion_depth': recursion_depth + 1}
                                     for _x in list_of_values(x)])
            pool.close()
        else:
            # deeper levels recurse sequentially inside the worker process
            result = [x] + map(f_helper,
                               [{'x': _x, 'depth': depth - 1, 'recursion_depth': recursion_depth + 1}
                                for _x in list_of_values(x)])

        return result

def list_of_values(x):
    # Heavy compute, pure function
Alex Brooks, answered Oct 10 '22