Can I override a C++ virtual function within Python with Cython?

Tags: c++, python, cython

I have a C++ class with a virtual method:

//C++
class A
{

    public:
        A() {};
        virtual int override_me(int a) {return 2*a;};
        int calculate(int a) { return this->override_me(a) ;}

};

What I would like to do is expose this class to Python with Cython, inherit from it in Python, and have the correct override called:

#python:
class B(PyA):
   def override_me(self, a):
       return 5*a
b = B()
b.calculate(1)  # should return 5 instead of 2

Is there a way to do this? Now that I think about it, it would also be great if the virtual method could be overridden in Cython as well (in a .pyx file), but letting users do this in pure Python is more important.

Edit: If this helps, a solution could be to use the pseudocode given here: http://docs.cython.org/src/userguide/pyrex_differences.html#cpdef-functions

But there are two problems:

  • I don't know how to write this pseudocode in Cython
  • Maybe there is a better approach
asked Apr 12 '12 15:04 by ascobol


1 Answer

The solution is somewhat complicated, but it is possible. There is a fully working example here: https://bitbucket.org/chadrik/cy-cxxfwk/overview

Here is an overview of the technique:

Create a specialized subclass of class A whose purpose is to interact with a Cython extension:

// created by cython when providing 'public api' keywords:
#include "mycymodule_api.h"

class CyABase : public A
{
public:
  PyObject *m_obj;

  CyABase(PyObject *obj);
  virtual ~CyABase();
  virtual int override_me(int a);
};

The constructor takes a Python object, the instance of our Cython extension type, and keeps a reference to it; the destructor releases that reference:

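CyABase::CyABase(PyObject *obj) :
  m_obj(obj)
{
  // import_mycymodule() is provided by "mycymodule_api.h"; it returns 0 on
  // success and -1 (with a Python exception set) on failure
  if (import_mycymodule() == 0) {
    Py_XINCREF(this->m_obj);
  } else {
    // the callbacks are unavailable: forget the object so override_me falls
    // back to A and the destructor doesn't release a reference it never took
    this->m_obj = NULL;
  }
}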

CyABase::~CyABase()
{
  Py_XDECREF(this->m_obj);
}

Wrap this subclass in a Cython extension type, implementing all non-virtual methods in the standard fashion:

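from cpython.ref cimport PyObject

cdef extern from "CyABase.h":  # hypothetical header declaring CyABase
    cdef cppclass CyABase:
        CyABase(PyObject *obj)
        int calculate(int a)

cdef class PyA:
    cdef CyABase* thisptr
    def __cinit__(self):
        self.thisptr = new CyABase(<PyObject*>self)
    def __dealloc__(self):
        # can only run once CyABase's own reference to self has been released
        del self.thisptr

    #------- non-virtual methods --------
    def calculate(self, int a):
        return self.thisptr.calculate(a)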

For each virtual and pure virtual method, create a public api function that takes the extension instance, the method arguments, and an error pointer:

cdef public api int cy_call_override_me(object self, int a, int *error):
    try:
        func = self.override_me
    except AttributeError:
        # no Python override: tell the C++ caller to use A::override_me
        error[0] = 1
        return 0  # ignored by the caller when error[0] is set
    else:
        error[0] = 0
        return func(a)
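Because the callback relies only on ordinary attribute lookup, its contract (look the method up on the instance, report through an out-parameter whether it was found) can be checked in plain Python before any C++ is compiled. The sketch below is a stand-in, not the Cython function itself: a one-element list replaces the `int *error` pointer, it returns 0 when no override exists (a value the C++ caller never uses), and `NoOverride`/`WithOverride` are hypothetical test classes.

```python
def cy_call_override_me(obj, a, error):
    """Pure-Python stand-in for the Cython `public api` callback.

    `error` is a one-element list playing the role of the C `int *error`.
    """
    try:
        func = obj.override_me
    except AttributeError:
        error[0] = 1  # no override: the C++ side should call A::override_me
        return 0      # placeholder; the caller ignores it when error[0] == 1
    else:
        error[0] = 0
        return func(a)


class NoOverride:
    """Hypothetical object without an override_me method."""


class WithOverride:
    """Hypothetical object whose override_me multiplies by 5."""

    def override_me(self, a):
        return 5 * a


err = [None]
print(cy_call_override_me(WithOverride(), 3, err), err[0])  # 15 0
print(cy_call_override_me(NoOverride(), 3, err), err[0])    # 0 1
```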

Use these functions in your intermediate C++ class like this:

int
CyABase::override_me(int a)
{
  if (this->m_obj) {
    int error = 0;
    // call the Python override, if it exists
    int result = cy_call_override_me(this->m_obj, a, &error);
    if (error)
      // no override found: call the parent method
      result = A::override_me(a);
    return result;
  }
  // no Python object attached: fall back to the parent method
  return A::override_me(a);
}
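To see how calculate, the intermediate class and the callback fit together, the round trip can be modelled in plain Python. This is an illustration of the control flow only, with no C++ or Cython involved: `A` plays the C++ base class, `CyABase` the intermediate subclass holding `m_obj`, `PyA` the extension type, and the explicit `A.override_me(self, a)` call plays the qualified (non-virtual) C++ call to the parent method. The null check on `m_obj` is left out because `PyA` always passes itself.

```python
class A:
    """Plays the C++ class A from the question."""

    def override_me(self, a):
        return 2 * a

    def calculate(self, a):
        # Virtual call in C++: reaches CyABase.override_me below.
        return self.override_me(a)


def cy_call_override_me(obj, a, error):
    """Plays the Cython callback; `error` is a one-element list (int *)."""
    try:
        func = obj.override_me
    except AttributeError:
        error[0] = 1
        return 0
    error[0] = 0
    return func(a)


class CyABase(A):
    """Plays the C++ intermediate class holding the Python object m_obj."""

    def __init__(self, obj):
        self.m_obj = obj

    def override_me(self, a):
        error = [0]
        result = cy_call_override_me(self.m_obj, a, error)
        if error[0]:
            # Like the qualified C++ call to A::override_me: no virtual dispatch.
            result = A.override_me(self, a)
        return result


class PyA:
    """Plays the cdef extension type; it deliberately has no override_me."""

    def __init__(self):
        self.thisptr = CyABase(self)

    def calculate(self, a):
        return self.thisptr.calculate(a)


class B(PyA):
    """The user's pure-Python subclass from the question."""

    def override_me(self, a):
        return 5 * a


print(PyA().calculate(1))  # 2: no override, falls back to A
print(B().calculate(1))    # 5: the Python override is called
```

The model also shows why PyA must not define an override_me of its own that forwards to the C++ object: for instances that don't override it, the callback would find that forwarding method and call straight back into CyABase's override_me, recursing without end.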

I quickly adapted my code to your example, so there could be mistakes. Take a look at the full example in the repository; it should answer most of your questions. Feel free to fork it and add your own experiments; it's far from complete!

answered Oct 21 '22 15:10 by chadrik