Alternative use patterns for python multiprocessing avoiding proliferation of global state?

This (enormously simplified example) works fine (Python 2.6.6, Debian Squeeze):

from multiprocessing import Pool
import numpy as np

src=None

def process(row):
    return np.sum(src[row])

def main():
    global src
    src=np.ones((100,100))

    pool=Pool(processes=16)
    rows=pool.map(process,range(100))
    print rows

if __name__ == "__main__":
    main()

However, after years of being taught "global state bad!!!", all my instincts tell me I'd really, really rather be writing something closer to:

from multiprocessing import Pool
import numpy as np

def main():
    src=np.ones((100,100))

    def process(row):
        return np.sum(src[row])

    pool=Pool(processes=16)
    rows=pool.map(process,range(100))
    print rows

if __name__ == "__main__":
    main()

But of course that doesn't work (it hangs, unable to pickle something).
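A minimal sketch (Python 3, stdlib only; the names `module_level`, `make_nested` and `is_picklable` are hypothetical) of why the nested version fails: `pool.map` has to pickle the function it sends to the workers, and pickle stores a function as a reference to its importable, module-level name. A function defined inside `main()` has no such name, so pickling it raises an error.

```python
import pickle

def module_level(row):
    # Defined at module top level: pickle can store it as a name reference.
    return row * 2

def make_nested():
    factor = 2
    def nested(row):
        # A closure local to make_nested(): pickle cannot look it up by name.
        return row * factor
    return nested

def is_picklable(obj):
    """Return True if pickle.dumps succeeds on obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exact exception type for local objects varies between versions.
        return False

if __name__ == "__main__":
    print(is_picklable(module_level))   # True
    print(is_picklable(make_nested()))  # False
```

Anything that can be pickled by reference (a module-level function, or an instance of a module-level class) can be handed to `pool.map`; closures and lambdas cannot.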

The example here is trivial, but by the time you add multiple "process" functions, each dependent on multiple additional inputs... well, it all becomes a bit reminiscent of something written in BASIC 30 years ago. Using classes to at least aggregate the state with the appropriate functions seems an obvious solution, but doesn't seem to be that easy in practice.

Is there some recommended pattern or style for using multiprocessing.Pool that avoids the proliferation of global state needed to support each function I want to map over in parallel?

How do experienced "multiprocessing pros" deal with this?

Update: Note that I'm actually interested in processing much bigger arrays, so variations on the above which pickle src each call/iteration aren't nearly as good as ones which fork it into the pool's worker processes.
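One way to get the data into each worker once, rather than with every task, is `Pool`'s `initializer`/`initargs` parameters. The sketch below is an illustration in Python 3 syntax, not something from the thread; `_init_worker`, `_process_row`, `_worker_src` and `row_sums` are hypothetical names. It still uses a module-level variable, but that variable is only written once per worker process by the initializer, so the global is confined to the worker side. With the "fork" start method (the Linux default before Python 3.14) `src` is inherited by the forked workers; with "spawn" or "forkserver" it is pickled once per worker.

```python
from multiprocessing import Pool
import numpy as np

# Hypothetical per-worker slot, written only by _init_worker.
_worker_src = None

def _init_worker(src):
    # Runs once in each worker process when the pool starts.
    global _worker_src
    _worker_src = src

def _process_row(row):
    # Only the row index travels with each task; the array is already here.
    return float(np.sum(_worker_src[row]))

def row_sums(src, processes=4):
    with Pool(processes=processes, initializer=_init_worker,
              initargs=(src,)) as pool:
        return pool.map(_process_row, range(src.shape[0]))

if __name__ == "__main__":
    print(row_sums(np.ones((100, 100)))[:5])
```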

asked Apr 14 '12 by timday


1 Answer

You could always pass a callable object like this; the object can then contain the shared state:

from multiprocessing import Pool
import numpy as np

class RowProcessor(object):
    def __init__(self, src):
        self.__src = src
    def __call__(self, row):
        return np.sum(self.__src[row])

def main():
    src=np.ones((100,100))
    p = RowProcessor(src)

    pool=Pool(processes=16)
    rows = pool.map(p, range(100))
    print rows

if __name__ == "__main__":
    main()
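The same callable-object approach in Python 3 syntax, as a sketch (`chunksize=25` and the `processes` parameter are illustrative choices, not part of the answer). One caveat relevant to the question's update: `pool.map` pickles the callable together with each chunk of work it sends, so `src` is copied once per chunk rather than once per worker. A larger `chunksize` means fewer copies; for very large arrays, passing the data through `Pool`'s `initializer`/`initargs` avoids the per-chunk copies.

```python
from multiprocessing import Pool
import numpy as np

class RowProcessor:
    """Bundles the shared array with the function that uses it."""
    def __init__(self, src):
        self._src = src

    def __call__(self, row):
        return float(np.sum(self._src[row]))

def main(processes=4):
    src = np.ones((100, 100))
    p = RowProcessor(src)
    with Pool(processes=processes) as pool:
        # The RowProcessor instance (and src with it) is pickled and sent
        # along with each chunk of rows; chunksize=25 gives 4 chunks here.
        rows = pool.map(p, range(src.shape[0]), chunksize=25)
    return rows

if __name__ == "__main__":
    print(main())
```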
answered Sep 19 '22 by KillianDS