
Python multiprocessing for expensive operation with 2D array

I have a function, say fun(a, b), that is quite expensive and returns a set of data.

My current method looks like:

a = np.linspace(0,100,300)
b = np.linspace(0,100,300)
A,B = np.meshgrid(a,b)
Y = np.zeros(A.shape)

for i,av in enumerate(a):
  for j, bv in enumerate(b):
    Y[i,j] = fun(av,bv)

(Sort of; I had to muddle things a bit to make it fit.) Anyway, the process takes quite a long time, and I was wondering if there was a straightforward way to use my multicore processor to speed things up.

asked Feb 18 '23 by RodericDay

1 Answer

There is a great module called multiprocessing that is part of the Python standard library. It will spawn processes across as many cores as you like, taking advantage of your other CPUs. There is an example of using the Pool object in the docs; below is a shortened version of that example. It calculates the square of 10 numbers, distributing the workload across four worker processes and displaying the result.

Simple worker pool

from multiprocessing import Pool

def f(x):
    return x * x

if __name__ == '__main__':
    # The guard keeps worker processes from re-running this block on import
    pool = Pool(processes=4)
    print(pool.map(f, range(10)))

Output

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

It took a bit more work to break your problem into that same structure; I had to create some intermediary functions to get the job done. I don't have numpy, so I just used lists and dictionaries in place of your arrays. You can swap them back in and try out the code (there is a numpy sketch after the output below).

More involved scenario

from multiprocessing import Pool
import time, pprint

def fun(av, bv):
    time.sleep(0.1)  # simulate an expensive computation
    return (av, bv)

def data_stream(a, b):
    # Yield ((i, j), (av, bv)) pairs so each work item carries its grid index
    for i, av in enumerate(a):
        for j, bv in enumerate(b):
            yield (i, j), (av, bv)

def proxy(args):
    # Unpack one work item: keep the index, apply fun to the values
    return args[0], fun(*args[1])

if __name__ == '__main__':
    a = range(100, 400, 100)
    b = range(100, 400, 100)
    Y = {}

    pool = Pool(processes=4)
    results = pool.map(proxy, data_stream(a, b))
    for k, v in results:
        Y[k] = v

    pprint.pprint(Y)

Output

{(0, 0): (100, 100),
 (0, 1): (100, 200),
 (0, 2): (100, 300),
 (1, 0): (200, 100),
 (1, 1): (200, 200),
 (1, 2): (200, 300),
 (2, 0): (300, 100),
 (2, 1): (300, 200),
 (2, 2): (300, 300)}
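
If you do have numpy, a sketch of the same idea writing straight back into a 2D array follows (untested on my end, since I don't have numpy installed; fun here is just a stand-in that returns a scalar):

from multiprocessing import Pool
import numpy as np

def fun(av, bv):
    # stand-in for the real expensive function; must return a scalar
    # so it can be stored in Y[i, j]
    return av + bv

def data_stream(a, b):
    for i, av in enumerate(a):
        for j, bv in enumerate(b):
            yield (i, j), (av, bv)

def proxy(args):
    return args[0], fun(*args[1])

if __name__ == '__main__':
    a = np.linspace(0, 100, 300)
    b = np.linspace(0, 100, 300)
    Y = np.zeros((len(a), len(b)))

    pool = Pool(processes=4)
    for (i, j), v in pool.map(proxy, data_stream(a, b)):
        Y[i, j] = v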

Performance

In the example I just put in a dummy 0.1 second delay to simulate heavy work, but even so, running the Pool with processes=1 takes 0.950s while processes=4 takes 0.352s. You can use the multiprocessing library in many different ways; Pool is just one of them. You might want to explore the examples and experiment.
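
If you want to time it yourself, something like this rough sketch works (slow() just stands in for the heavy work; it is not part of the example above):

from multiprocessing import Pool
import time

def slow(x):
    time.sleep(0.1)  # stand-in for 0.1s of heavy work per item
    return x

if __name__ == '__main__':
    for n in (1, 4):
        pool = Pool(processes=n)
        start = time.time()
        pool.map(slow, range(9))  # 9 items, like the 3x3 grid above
        pool.close()
        pool.join()
        print('processes=%d: %.3fs' % (n, time.time() - start))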

In one of the comments below there was a mention of using the chunksize argument to pool.map to help improve performance. It is important to have a general understanding of what is going on under the hood to really get a handle on performance: all the data you pass to the worker processes needs to be pickled, sent to the other process, unpickled, and used, and the result then goes through the same steps on the way back to the main process. There is an overhead to this inter-process communication; keep it in mind when you experiment.
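
To experiment with chunksize, something like this is a reasonable starting point (the item count and chunk size here are just illustrative values to tune for your workload):

from multiprocessing import Pool
import time

def slow(x):
    time.sleep(0.01)  # cheap enough per item that IPC overhead matters
    return x * x

if __name__ == '__main__':
    pool = Pool(processes=4)
    # chunksize controls how many items are pickled and shipped to a worker
    # per message: bigger chunks cut IPC overhead, smaller chunks balance
    # the load better when item costs vary.
    results = pool.map(slow, range(1000), chunksize=50)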

answered Feb 21 '23 by Marwan Alsabbagh