 

Python Multiprocessing Pool as Decorator

I'm working on code where I frequently have to use Python's multiprocessing Pool class. This results in a ton of code that looks like this:

import time
from multiprocessing import Pool
from functools import partial

def test_func(x):
    time.sleep(1)
    return x

def test_func_parallel(iterable, processes):
    p = Pool(processes=processes)
    output = p.map(test_func, iterable)
    p.close()
    return output

This can be made more general:

def parallel(func, iterable, **kwargs):
    func = partial(func, **kwargs)
    p = Pool(processes=6)
    out = p.map(func, iterable)
    p.close()
    return out

This works, but writing a separate parallel wrapper for every function clutters the code. What I'd really like is to turn this into a decorator, something like this:

def parallel(num_processes):
    def parallel_decorator(func):
        def parallel_wrapper(iterable, **kwargs):
            # Use a new name: reassigning `func` here would make it a local
            # variable and raise UnboundLocalError before Pool is ever reached.
            bound = partial(func, **kwargs)
            p = Pool(processes=num_processes)
            output = p.map(bound, iterable)
            p.close()
            return output

        return parallel_wrapper
    return parallel_decorator

Which could be used as follows:

@parallel(6)
def test_func(x):
    time.sleep(1)
    return x

This fails with a pickling error:

Can't pickle <function test1 at 0x117473268>: it's not the same object as __main__.test1

I've read a few posts on related issues, but they all implement a solution where the multiprocessing is executed outside the decorator. Does anyone know a way to make this work?
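The failure can be reproduced with pickle alone, because Pool.map pickles the function to send it to the worker processes, and pickle stores a plain function by reference (its module plus its __qualname__). In this minimal sketch the names decorate, decorated, plain, plain_wrapped, can_pickle, and the .original attribute are all hypothetical. With @ syntax, the module attribute decorated is the wrapper, so the original function held by the wrapper no longer matches __main__.decorated and pickling it fails. Binding the wrapper to a different name, as the answer below does with test_parallel, keeps the original reachable under its own name. It assumes the file is run as a script.

```python
import pickle


def decorate(func):
    def wrapper(*args):
        return func(*args)
    # Hypothetical attribute: expose the wrapped function so we can pickle it directly.
    wrapper.original = func
    return wrapper


@decorate
def decorated(x):
    return x * 2


def plain(x):
    return x * 2


# Same pattern as the answer: the wrapper is bound to a *different* name.
plain_wrapped = decorate(plain)


def can_pickle(obj):
    """Return True if pickle can serialize obj (functions are pickled by name)."""
    try:
        pickle.dumps(obj)
        return True
    except pickle.PicklingError:
        return False


if __name__ == "__main__":
    # __main__.decorated is the wrapper, not the original function.
    print(can_pickle(decorated.original))      # False
    # __main__.plain still names the original function, so lookup by name succeeds.
    print(can_pickle(plain_wrapped.original))  # True
```

Pickling the wrapper itself fails as well, since a nested function cannot be looked up by name at module level.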

asked Apr 16 '21 at 20:04 by Karl


1 Answer

If you don't mind not using the decorator syntactic sugar (the @ symbol), something like this should work:

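import functools
import time

from multiprocessing import Pool


def parallel(func=None, **options):
    # Called as parallel(processes=6): return a partial that waits for func.
    if func is None:
        return functools.partial(parallel, **options)

    def wrapper(iterable, **kwargs):
        # None lets Pool default to os.cpu_count() worker processes.
        processes = options.get("processes")
        # Bind extra keyword arguments; the partial pickles fine because
        # func is still reachable under its own module-level name.
        task = functools.partial(func, **kwargs)

        with Pool(processes) as pool:
            result = pool.map(task, iterable)

        return result

    return wrapper


def test(i):
    time.sleep(1)
    print(f"{i}: {i * i}")


# Bind the wrapper to a *new* name so that __main__.test still refers to
# the original function, which is what pickle looks up.
test_parallel = parallel(test, processes=6)


def main():
    test_parallel(range(10))


if __name__ == "__main__":
    main()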
answered Oct 19 '22 at 01:10 by HTF