I am trying to use multiprocessing with pool.map to speed up execution of a function, but the iterable is not the first argument of that function. Lambdas won't work because they aren't picklable. I tried to use functools.partial to create a new function, but it fails with a TypeError. Below is a very simple example with the same result. If I switch the argument order to f(i, s1, s2), it works as expected.
Why does the argument order matter here? It is not obvious to me from reading the docs.
What are my options (other than the obvious of changing the original function)?
```python
import multiprocessing
from functools import partial

def f(s1, s2, i):
    return [s1] + [s2]*i

def main():
    # other code... constants for f aren't known until runtime
    pool = multiprocessing.Pool()
    func = partial(f, s1='a', s2='c')
    for strings in pool.map(func, range(10)):
        print(strings)
    pool.close()
    pool.join()

if __name__ == '__main__':
    main()
```
Update: The best I can think of is to create a wrapper at the module level to switch the argument order, then create a partial from the wrapper. It doesn't look pretty or seem Pythonic at all.
```python
import multiprocessing
from functools import partial

def f(s1, s2, i):
    return [s1] + [s2]*i

def wrapper(i, s1, s2):
    return f(s1, s2, i)

def main():
    # other code... constants for f aren't known until runtime
    pool = multiprocessing.Pool()
    func = partial(wrapper, s1='foo', s2='bar')
    for strings in pool.map(func, range(10)):
        print(strings)
    pool.close()
    pool.join()

if __name__ == '__main__':
    main()
```
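The order matters because pool.map calls func(i) with each item as a positional argument, so it ends up invoking f(i, s1='a', s2='c'). Positional arguments fill parameters from the left, so i is bound to s1, and then the keyword s1='a' supplies s1 a second time, which raises TypeError: f() got multiple values for argument 's1'.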
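The collision can be reproduced without a pool at all, which makes it easier to see that multiprocessing is not the culprit. This minimal sketch calls the question's partial directly with one positional argument, the way Pool.map does; the exact wording of the error message may vary between Python versions.

```python
from functools import partial

def f(s1, s2, i):
    return [s1] + [s2] * i

# The same partial as in the question: s1 and s2 are bound by keyword.
func = partial(f, s1='a', s2='c')

# Pool.map calls func(item) with the item as the only positional argument,
# which amounts to f(0, s1='a', s2='c'): 0 fills s1, then s1='a' collides.
error_message = None
try:
    func(0)
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # e.g. f() got multiple values for argument 's1'

# Passing i by keyword avoids the collision, but Pool.map cannot do that.
print(func(i=3))  # ['a', 'c', 'c', 'c']
```

Pool.map only ever passes the item positionally, so any keyword-bound partial of f will hit the same error.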
You could write your "partial" by hand as a module-level function:
```python
import multiprocessing

def f(s1, s2, i):
    return [s1] + [s2]*i

def f2(i):
    return f('a','c',i)

if __name__ == '__main__':
    pool = multiprocessing.Pool()
    for strings in pool.map(f2, range(10)):
        print(strings)
    pool.close()
    pool.join()
```
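A variant that keeps the constants configurable at runtime (which the hard-coded f2 cannot do) is to bind s1 and s2 positionally rather than by keyword. partial(f, s1, s2) prepends those values, so the item Pool.map passes in lands in i, and a partial of a module-level function can be pickled. This is a minimal sketch, assuming Python 3.3+ for the with-statement form of Pool (on older versions, keep the explicit close()/join() calls); the main(s1, s2) signature is hypothetical and only for illustration.

```python
import multiprocessing
from functools import partial

def f(s1, s2, i):
    return [s1] + [s2] * i

def main(s1, s2):
    # Bind the leading parameters positionally; Pool.map then supplies
    # each item as the remaining positional argument, i.
    func = partial(f, s1, s2)
    # Leaving the with-block calls pool.terminate(), which is safe here
    # because map() has already returned all of its results.
    with multiprocessing.Pool() as pool:
        return pool.map(func, range(10))

if __name__ == '__main__':
    # Stand-ins for values that are only known at runtime.
    for strings in main('a', 'c'):
        print(strings)
```

This only works because the fixed arguments come first in f's signature; for a parameter in the middle you would still need a wrapper or starmap.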
If you are using Python 3.3 or later, Pool.starmap is available to you. It unpacks each item of the iterable into the function's arguments:
```python
import multiprocessing
from itertools import repeat

def f(s1, s2, i):
    return [s1] + [s2]*i

if __name__ == '__main__':
    pool = multiprocessing.Pool()
    # zip stops at the shortest input, so the infinite repeat() iterators are fine
    for strings in pool.starmap(f, zip(repeat('a'), repeat('c'), range(10))):
        print(strings)
    pool.close()
    pool.join()
```
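On Python versions before 3.3, where starmap does not exist, one common workaround is to build the argument tuples yourself and map a module-level helper that unpacks them. This is a sketch under that assumption; the helper name f_star is hypothetical, and it must live at module level so that it can be pickled.

```python
import multiprocessing

def f(s1, s2, i):
    return [s1] + [s2] * i

def f_star(args):
    # Module-level (hence picklable) helper that unpacks one argument
    # tuple, emulating what Pool.starmap does for each item.
    return f(*args)

if __name__ == '__main__':
    s1, s2 = 'a', 'c'  # stand-ins for values known only at runtime
    tasks = [(s1, s2, i) for i in range(10)]
    pool = multiprocessing.Pool()
    for strings in pool.map(f_star, tasks):
        print(strings)
    pool.close()
    pool.join()
```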