Argument order with python functools partial

I am trying to use multiprocessing with pool.map to speed up execution of a function, but the iterable is not the first argument of that function. Lambdas won't work because they can't be pickled. I tried to use functools.partial to create a new function, but it fails with a TypeError. Below is a very simple example with the same result. If I switch the argument order to f(i, s1, s2), it works as expected.

Why does the argument order matter here? It is not obvious to me from reading the docs.

What are my options (other than the obvious one of changing the original function)?

import multiprocessing
from functools import partial


def f(s1, s2, i):
    return [s1] + [s2]*i

def main():
    # other code... constants for f aren't known until runtime
    pool = multiprocessing.Pool()
    func = partial(f, s1='a', s2='c')
    for strings in pool.map(func, range(10)):
        print(strings)
    pool.close()
    pool.join()

if __name__ == '__main__':
    main()

Update: The best I can think of is to create a wrapper at the module level that switches the argument order, then build a partial from the wrapper. It doesn't look pretty or seem pythonic at all.

import multiprocessing
from functools import partial


def f(s1, s2, i):
    return [s1] + [s2]*i

def wrapper(i, s1, s2):
    return f(s1, s2, i)

def main():
    # other code... constants for f aren't known until runtime
    pool = multiprocessing.Pool()
    func = partial(wrapper, s1='foo', s2='bar')
    for strings in pool.map(func, range(10)):
        print(strings)
    pool.close()
    pool.join()

if __name__ == '__main__':
    main()
user2133814 asked Mar 15 '26 12:03


1 Answer

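The order matters because pool.map calls func(i) with a single positional argument, so the partial ends up invoking f(i, s1='a', s2='c'). The positional i is bound to the first parameter, s1, which then collides with the keyword s1='a' and raises the TypeError. Instead of a partial, you could define a module-level function that fills in the constants: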

import multiprocessing

def f(s1, s2, i):
    return [s1] + [s2]*i

def f2(i):
    return f('a','c',i)

if __name__ == '__main__':
    pool = multiprocessing.Pool()
    for strings in pool.map(f2, range(10)):
        print(strings)
    pool.close()
    pool.join()
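
To make the collision concrete, here is a minimal sketch that reproduces it without a pool. It also shows one alternative that I am adding for illustration (it is not part of the code above): binding s1 and s2 positionally in the partial, so that the single argument supplied by the caller falls through to i. The names kw_func, pos_func, error_message and result are made up for this example.

```python
from functools import partial


def f(s1, s2, i):
    return [s1] + [s2] * i


# Keyword-bound partial, as in the question.
kw_func = partial(f, s1='a', s2='c')

error_message = None
try:
    # Equivalent to f(2, s1='a', s2='c'): the positional 2 is bound to s1,
    # which clashes with the keyword s1='a'.
    kw_func(2)
except TypeError as exc:
    error_message = str(exc)

# Positionally bound partial: s1 and s2 are filled first,
# so the caller's single argument becomes i.
pos_func = partial(f, 'a', 'c')
result = pos_func(2)  # equivalent to f('a', 'c', 2)

print(error_message)
print(result)
```

Since f is defined at module level, a partial built this way should be picklable just like the keyword-bound one, so it could be passed to pool.map in place of f2 when the constants are only known at runtime.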

If you are using Python 3.3 or later, pool.starmap is available to you. It takes an iterable of argument tuples and unpacks each one into a call to f:

import multiprocessing
from itertools import repeat

def f(s1, s2, i):
    return [s1] + [s2]*i

if __name__ == '__main__':
    pool = multiprocessing.Pool()
    for strings in pool.starmap(f, zip(repeat('a'), repeat('c'), range(10))):
        print(strings)
    pool.close()
    pool.join()
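
To see which argument tuples the zip(...) expression hands to the pool, the same tuples can be fed to the serial itertools.starmap. This is only a sketch of the unpacking behaviour, not a substitute for the parallel version; the names arg_tuples and serial_results are made up for this example.

```python
from itertools import repeat, starmap


def f(s1, s2, i):
    return [s1] + [s2] * i


# zip stops at the shortest input, so the infinite repeat() iterators
# are cut off when range(3) is exhausted.
arg_tuples = list(zip(repeat('a'), repeat('c'), range(3)))

# starmap unpacks each tuple into positional arguments: f(*args).
serial_results = list(starmap(f, arg_tuples))

print(arg_tuples)
print(serial_results)
```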
Fabricator answered Mar 17 '26 01:03