Pybind11: Possible to use mpi4py?

Is it possible in Pybind11 to use mpi4py on the Python side and then to hand over the communicator to the C++ side?

If so, how would it work?

If not, is it possible for example with Boost? And if so, how would it be done?

I searched the web literally for hours but didn't find anything.

asked Dec 13 '22 17:12 by Quasar


1 Answer

Passing an mpi4py communicator to C++ using pybind11 can be done using the mpi4py C-API. The corresponding header files can be located using the following Python code:

import mpi4py
print(mpi4py.get_include())

To conveniently pass communicators between Python and C++, a custom pybind11 type caster can be implemented. For this purpose, we start with the typical preamble.

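// native.cpp
#include <iostream>   // std::cout, used by the example functions
#include <stdexcept>  // std::runtime_error, used in the module init
#include <pybind11/pybind11.h>
#include <mpi.h>
#include <mpi4py/mpi4py.h>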

namespace py = pybind11;

In order for pybind11 to automatically convert a Python type to a C++ type, we need a distinct type that the C++ compiler can recognise. Unfortunately, the MPI standard does not specify the underlying type of MPI_Comm. Worse, common MPI implementations define MPI_Comm as a typedef for an existing type, for example int in MPICH or a pointer type in Open MPI, which the C++ compiler cannot distinguish from regular uses of that type. To create a distinct type, we define a wrapper class for MPI_Comm which implicitly converts to and from MPI_Comm.

struct mpi4py_comm {
  mpi4py_comm() = default;
  mpi4py_comm(MPI_Comm value) : value(value) {}
  operator MPI_Comm () const { return value; }  // const: const objects convert too

  MPI_Comm value;
};
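To illustrate why the wrapper is needed, here is a minimal sketch that does not require MPI. It assumes an implementation that defines its handle as a plain int; fake_comm, wrapped_comm and caster_name are hypothetical names made up for this illustration, with caster_name standing in for a type caster specialization.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-in for an MPI implementation whose communicator
// handle is a plain int; the real type would be MPI_Comm.
typedef int fake_comm;

// Wrapper in the style of mpi4py_comm.
struct wrapped_comm {
  wrapped_comm() = default;
  wrapped_comm(fake_comm value) : value(value) {}
  operator fake_comm() const { return value; }

  fake_comm value = 0;
};

// Toy trait standing in for a type caster: one specialization per type.
template <typename T> struct caster_name {
  static std::string get() { return "generic"; }
};
template <> struct caster_name<fake_comm> {
  static std::string get() { return "communicator"; }
};
template <> struct caster_name<wrapped_comm> {
  static std::string get() { return "wrapped communicator"; }
};

// The typedef is only an alias, so the fake_comm specialization is in
// fact a specialization for int and captures every ordinary int.
static_assert(std::is_same<fake_comm, int>::value, "alias of int");

// The wrapper is a type of its own; specializing on it affects nothing else.
static_assert(!std::is_same<wrapped_comm, int>::value, "distinct type");
```

In this sketch caster_name<int>::get() yields "communicator" although no communicator is involved, whereas the wrapped_comm specialization is selected only for the wrapper.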

The type caster is then implemented as follows:

namespace pybind11 { namespace detail {
  template <> struct type_caster<mpi4py_comm> {
    public:
      PYBIND11_TYPE_CASTER(mpi4py_comm, _("mpi4py_comm"));

      // Python -> C++
      bool load(handle src, bool) {
        PyObject *py_src = src.ptr();

        // Check that we have been passed an mpi4py communicator
        if (PyObject_TypeCheck(py_src, &PyMPIComm_Type)) {
          // Convert to regular MPI communicator
          value.value = *PyMPIComm_Get(py_src);
        } else {
          return false;
        }

        return !PyErr_Occurred();
      }

      // C++ -> Python
      static handle cast(mpi4py_comm src,
                         return_value_policy /* policy */,
                         handle /* parent */)
      {
        // Create an mpi4py handle
        return PyMPIComm_New(src.value);
      }
  };
}} // namespace pybind11::detail

Below is the code of an example module which uses the type caster. Note that we use mpi4py_comm instead of MPI_Comm in the signatures of the functions exposed to pybind11. Due to the implicit conversion, however, these variables can be used like regular MPI_Comm variables. In particular, they can be passed to any function expecting an argument of type MPI_Comm.

// Receive a communicator and check whether it equals MPI_COMM_WORLD
void print_comm(mpi4py_comm comm)
{
  if (comm == MPI_COMM_WORLD) {
    std::cout << "Received the world." << std::endl;
  } else {
    std::cout << "Received something else." << std::endl;
  }
}

mpi4py_comm get_comm()
{
  return MPI_COMM_WORLD; // Just return MPI_COMM_WORLD for demonstration
}

PYBIND11_MODULE(native, m)
{
  // import the mpi4py API
  if (import_mpi4py() < 0) {
    throw std::runtime_error("Could not load mpi4py API.");
  }

  // register the test functions
  m.def("print_comm", &print_comm, "Do something with the mpi4py communicator.");
  m.def("get_comm", &get_comm, "Return some communicator.");
}

The module can be compiled, e.g., using

mpicxx -O3 -Wall -shared -std=c++14 -fPIC \
  $(python3 -m pybind11 --includes) \
  -I$(python3 -c 'import mpi4py; print(mpi4py.get_include())') \
  native.cpp -o native$(python3-config --extension-suffix)

and tested using

import native
from mpi4py import MPI
import math

native.print_comm(MPI.COMM_WORLD)

# Create a cart communicator for testing
# (MPI_COMM_WORLD.size has to be a square number)
d = math.isqrt(MPI.COMM_WORLD.size)  # integer square root; the dims must be ints
cart_comm = MPI.COMM_WORLD.Create_cart([d,d], [1,1], False)
native.print_comm(cart_comm)

print(f'native.get_comm() == MPI.COMM_WORLD '
      f'-> {native.get_comm() == MPI.COMM_WORLD}')

When run on a single process, the output should be:

Received the world.
Received something else.
native.get_comm() == MPI.COMM_WORLD -> True
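
As a side note on the implicit conversion used by the example functions, the following sketch shows both directions without requiring MPI. The names fake_comm, FAKE_COMM_WORLD, wrapped_comm, describe, describe_wrapped and make_world are hypothetical stand-ins, not part of MPI, mpi4py or pybind11. The conversion operator in the sketch is const-qualified, which is what lets the const reference convert.

```cpp
#include <string>

// Hypothetical stand-ins for MPI_Comm and MPI_COMM_WORLD.
typedef int fake_comm;
const fake_comm FAKE_COMM_WORLD = 42;

// Wrapper in the style of mpi4py_comm.
struct wrapped_comm {
  wrapped_comm() = default;
  wrapped_comm(fake_comm value) : value(value) {}
  // const-qualified so that const objects and const references convert
  operator fake_comm() const { return value; }

  fake_comm value = 0;
};

// Stand-in for a library routine that expects the raw handle.
std::string describe(fake_comm comm) {
  return comm == FAKE_COMM_WORLD ? "world" : "something else";
}

// Takes the wrapper by const reference, as an exposed function might.
std::string describe_wrapped(const wrapped_comm &comm) {
  return describe(comm);  // wrapper -> handle via the conversion operator
}

// Returns a raw handle; the converting constructor builds the wrapper.
wrapped_comm make_world() { return FAKE_COMM_WORLD; }
```

Calling describe_wrapped(make_world()) exercises the handle-to-wrapper conversion on return and the wrapper-to-handle conversion on the call to describe.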
answered Apr 01 '23 22:04 by H. Rittich