How to apply a function to a 2D numpy array with multiprocessing

Suppose I have the following function:

def f(x,y):
    return x*y

How do I apply the function to each element of an NxM 2D numpy array using the multiprocessing module? Using serial iteration, the code might look as follows:

import numpy as np
N = 10
M = 12
results = np.zeros(shape=(N,M))
for x in range(N):
    for y in range(M):
        results[x,y] = f(x,y)
asked Feb 10 '23 by TraxusIV

1 Answer

Here's how you might parallelize your example function using multiprocessing. I've also included an almost identical pure Python version that uses non-parallel for loops, and a vectorized numpy version that achieves the same result:

import numpy as np
from multiprocessing import Pool


def f(x,y):
    return x * y

# this helper function is needed because Pool.map() can only be used with
# functions that take a single argument (see http://stackoverflow.com/q/5442910/1461210);
# in Python 3 you could use Pool.starmap() instead and skip the helper
def splat_f(args):
    return f(*args)

# a pool of 8 worker processes (on platforms that spawn rather than fork
# new processes, e.g. Windows, create the pool inside an
# `if __name__ == "__main__":` block)
pool = Pool(8)

def parallel(M, N):
    results = pool.map(splat_f, ((i, j) for i in range(M) for j in range(N)))
    return np.array(results).reshape(M, N)

def nonparallel(M, N):
    out = np.zeros((M, N), int)  # np.int was removed in NumPy 1.24
    for i in range(M):
        for j in range(N):
            out[i, j] = f(i, j)
    return out

def broadcast(M, N):
    # np.ogrid[:M, :N] gives an (M, 1) and a (1, N) array; multiplying them
    # broadcasts to the full (M, N) grid of products
    x, y = np.ogrid[:M, :N]
    return x * y
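
As a quick sanity check (a minimal sketch, using small dimensions I picked arbitrarily), all three functions should produce identical arrays:

# all three approaches should agree element-for-element
assert np.array_equal(parallel(10, 12), nonparallel(10, 12))
assert np.array_equal(parallel(10, 12), broadcast(10, 12))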

Now let's look at the performance:

%timeit parallel(1000, 1000)
# 1 loops, best of 3: 1.67 s per loop

%timeit nonparallel(1000, 1000)
# 1 loops, best of 3: 395 ms per loop

%timeit broadcast(1000, 1000)
# 100 loops, best of 3: 2 ms per loop

The non-parallel pure Python version beats the parallelized version by a factor of about 4, and the version using numpy array broadcasting absolutely crushes the other two.

The problem is that starting and stopping Python subprocesses carries quite a lot of overhead, and your test function is so trivial that each worker process spends only a tiny proportion of its lifetime doing useful work. Multiprocessing only makes sense if each process has a substantial amount of work to do before it is killed. You might, for example, give each worker a bigger chunk of the output array to compute (try playing with the chunksize= parameter to pool.map(), and see the sketch below), but with such a trivial example I doubt you'll see a big improvement.
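
Here's a rough sketch of that row-chunking idea, reusing the f and pool defined above (f_row and parallel_rows are hypothetical helpers I've made up for illustration): each task computes a whole row of the output rather than a single element, and chunksize= batches several rows into one inter-process message:

def f_row(args):
    # compute an entire row of the output in one task, so each worker does
    # N multiplications per dispatch instead of one
    i, N = args
    return [f(i, j) for j in range(N)]

def parallel_rows(M, N):
    # chunksize=100 sends rows to the workers in batches of 100, cutting
    # the per-task communication overhead further
    results = pool.map(f_row, ((i, N) for i in range(M)), chunksize=100)
    return np.array(results)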

I don't know what your actual code looks like - maybe your function is big and expensive enough to warrant using multiprocessing. However, I would bet that there are much better ways to improve its performance.
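
To give a feel for when the pool does pay off, here's a contrived sketch (slow_f is a made-up stand-in for genuinely expensive per-element work, not anything from your code):

import math

def slow_f(args):
    # a deliberately expensive stand-in for real per-element work
    x, y = args
    return sum(math.sin(x * y + k) for k in range(10000))

def parallel_slow(M, N):
    # with this much work per task, the startup and communication overhead
    # is amortized, and the pool should beat the serial double loop
    results = pool.map(slow_f, ((i, j) for i in range(M) for j in range(N)))
    return np.array(results).reshape(M, N)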

answered Feb 12 '23 by ali_m