
Parallelize a function call with mpi4py

I want to use mpi4py to parallelize an optimization problem. To minimize my function I use the minimize routine from scipy:

from scipy.optimize import minimize

def f(x, data):
    # returns the objective value f(x) for the given data
    ...

x = minimize(f, x0, args=(data,))

Now I want to parallelize my function using mpi4py. The implementation of the minimization algorithm is sequential and can only run on one process, so only my function is parallelized. That is not a problem, since the function call is the most time-consuming step. But I can't figure out how to implement this problem, with its parallel and sequential parts.

Here is my attempt:

from scipy.optimize import minimize
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

N = 100        # for testing
step = N//size # assume N is divisible by size

def mpi_f(x, data):
    x0 = x[rank*step:(rank+1)*step]
    res = f(x0, data)
    res = comm.gather(res, root=0)
    if rank == 0:
        return res

if rank == 0:
    x = np.zeros(N)
    xs = minimize(mpi_f, x, args=(data,))

This is obviously not working, since mpi_f only runs on process 0. So I am asking: how should I proceed?
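As a side note, the block decomposition used above (step = N//size, each rank taking the slice x[rank*step:(rank+1)*step]) can be checked in plain Python, without any MPI (a standalone sanity check, not part of the MPI program):

```python
# Verify that the per-rank slices partition the full array.
# "size" plays the role of comm.Get_size(); no MPI is needed here.
N = 100
size = 4
step = N // size  # assumes N is divisible by size

x = list(range(N))
chunks = [x[rank * step:(rank + 1) * step] for rank in range(size)]

# Every chunk has the same length, and together they cover x exactly once.
assert all(len(c) == step for c in chunks)
assert [v for c in chunks for v in c] == x
```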

Thanks.

asked May 11 '16 by K.Hassan

People also ask

What is mpi4py in Python?

This package provides Python bindings for the Message Passing Interface (MPI) standard. It is implemented on top of the MPI specification and exposes an API which grounds on the standard MPI-2 C++ bindings.

Does NumPy use MPI?

MPI for Python supports convenient, pickle-based communication of generic Python objects as well as fast, near C-speed, direct array-data communication of buffer-provider objects (e.g., NumPy arrays). For generic objects you have to use the methods with all-lowercase names, like Comm.send and Comm.recv.
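The lowercase methods serialize generic objects with pickle under the hood. Conceptually, this is what comm.send does before transmitting and comm.recv does on arrival (a plain-Python sketch of the mechanism, not mpi4py code):

```python
import pickle

# comm.send(obj, ...) pickles the object into bytes before sending;
# comm.recv(...) unpickles the bytes on the receiving side.
obj = {"coords": [1.0, 2.0], "label": "x0"}
wire = pickle.dumps(obj)        # what travels between processes
roundtrip = pickle.loads(wire)  # what the receiver gets back

assert roundtrip == obj
```

The uppercase methods (Comm.Send, Comm.Recv) skip this serialization step and transmit the buffer directly, which is why they are much faster for NumPy arrays.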

How do I run an MPI program in Python?

The use of -m mpi4py to execute Python code on the command line resembles that of the Python interpreter:

    mpiexec -n numprocs python -m mpi4py pyfile [arg] ...
    mpiexec -n numprocs python -m mpi4py -m mod [arg] ...
    mpiexec -n numprocs python -m mpi4py -c cmd [arg] ...


1 Answer

In your code, the root process is the only one that calls comm.gather(), since it is the only one that calls the parallelized cost function. Hence, the program faces a deadlock. You are well aware of this issue.

To overcome this deadlock, the other processes must call the cost function as many times as minimize needs it. Since this number of calls is not known in advance, a while loop seems suitable for these processes.

The stopping condition of the while loop must be defined. This flag has to be broadcast from the root process to all processes, since the root process is the only one aware that the minimize() function has ended. The broadcast must be performed inside the cost function, since all processes must test for the end of minimize at each iteration. Since minimize makes use of the return value of the function, the flag is passed by reference via a mutable type.
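This "pass by reference via a mutable type" trick can be seen in isolation: a one-element list mutated inside the function is visible to the caller, which is what lets the workers' while loop observe the stop flag. Here is a minimal stdlib sketch of that control flow, with the broadcast replaced by a plain argument (a simulation, not the MPI code below):

```python
def worker_step(x, stopp, new_flag):
    # stands in for the cost function: in the real code, new_flag
    # would arrive via comm.bcast from the root process
    stopp[0] = new_flag
    if stopp[0] == 0:
        return x * x  # placeholder for the parallel evaluation
    return 0

stop = [0]
calls = 0
# the workers keep calling the cost function until the root flips the flag
for flag in (0, 0, 0, 1):
    if stop[0] != 0:
        break
    worker_step(2.0, stop, flag)
    calls += 1

# the mutation inside worker_step is visible out here via the list
assert stop[0] == 1
```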

Finally, here is a potential solution to your problem. It is run with mpirun -np 4 python main.py. I used fmin() instead of minimize() because my version of scipy is outdated.

#from scipy.optimize import minimize
from scipy.optimize import fmin
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

N = 100        # for testing
step = N//size # assume N is divisible by size

def parallel_function_caller(x, stopp):
    stopp[0] = comm.bcast(stopp[0], root=0)
    summ = 0
    if stopp[0] == 0:
        # your function, evaluated here in parallel
        x = comm.bcast(x, root=0)
        array = np.arange(x[0]-N/2.+rank*step-42, x[0]-N/2.+(rank+1)*step-42, 1.)
        summl = np.sum(np.square(array))
        summ = comm.reduce(summl, op=MPI.SUM, root=0)
        if rank == 0:
            print("value is " + str(summ))
    return summ

if rank == 0:
    stop = [0]
    x = np.zeros(1)
    x[0] = 20
    #xs = minimize(parallel_function_caller, x, args=(stop,))
    xs = fmin(parallel_function_caller, x0=x, args=(stop,))
    print("the argmin is " + str(xs))
    stop = [1]
    parallel_function_caller(x, stop)
else:
    stop = [0]
    x = np.zeros(1)
    while stop[0] == 0:
        parallel_function_caller(x, stop)
answered Oct 06 '22 by francis