
Cython/Python/C++ - Inheritance: Passing Derived Class as Argument to Function expecting base class

I am using Cython to wrap a set of C++ classes so that they can be used from Python. Example code is provided below:

BaseClass.h:

#ifndef __BaseClass__
#define __BaseClass__
#include <stdio.h>
#include <stdlib.h>
#include <string>
using namespace std;
class BaseClass
{
    public:
        BaseClass(){};
        virtual ~BaseClass(){};
        virtual void SetName(string name){printf("in base set name\n");}
        virtual float Evaluate(float time){printf("in base Evaluate\n");return 0;}
        virtual bool DataExists(){printf("in base data exists\n");return false;}
};
#endif /* defined(__BaseClass__) */

DerivedClass.h:

#ifndef __DerivedClass__
#define __DerivedClass__

#include "BaseClass.h"

class DerivedClass:public BaseClass
{
    public:
        DerivedClass(){};
        virtual ~DerivedClass(){};
        virtual float Evaluate(float time){printf("in derived Evaluate\n");return 1;}
        virtual bool DataExists(){printf("in derived data exists\n");return true;}
        virtual void MyFunction(){printf("in my function\n");}
        virtual void SetObject(BaseClass *input){printf("in set object\n");}
};
#endif /* defined(__DerivedClass__) */

NextDerivedClass.h:

#ifndef __NextDerivedClass__
#define __NextDerivedClass__

#include "DerivedClass.h"

class NextDerivedClass:public DerivedClass
{
    public:
        NextDerivedClass(){};
        virtual ~NextDerivedClass(){};
        virtual void SetObject(BaseClass *input){printf("in set object of next derived class\n");}
};
#endif /* defined(__NextDerivedClass__) */

inheritTest.pyx:

from libcpp cimport bool
from libcpp.string cimport string

cdef extern from "BaseClass.h":
    cdef cppclass BaseClass:
        BaseClass() except +
        void SetName(string)
        float Evaluate(float)
        bool DataExists()

cdef extern from "DerivedClass.h":
    cdef cppclass DerivedClass(BaseClass):
        DerivedClass() except +
        void MyFunction()
        float Evaluate(float)
        bool DataExists()
        void SetObject(BaseClass *)

cdef extern from "NextDerivedClass.h":
    cdef cppclass NextDerivedClass(DerivedClass):
        NextDerivedClass() except +
        # ***  The issue is right here ***
        void SetObject(BaseClass *)

cdef class PyBaseClass:
    cdef BaseClass *thisptr
    def __cinit__(self):
        if type(self) is PyBaseClass:
            self.thisptr = new BaseClass()
    def __dealloc__(self):
        if type(self) is PyBaseClass:
            del self.thisptr

cdef class PyDerivedClass(PyBaseClass):
    cdef DerivedClass *derivedptr
    def __cinit__(self):
        self.derivedptr = self.thisptr = new DerivedClass()
    def __dealloc__(self):
        del self.derivedptr
    # def Evaluate(self, time):
    #     return self.derivedptr.Evaluate(time)
    def SetObject(self, PyBaseClass inputObject):
        self.derivedptr.SetObject(<BaseClass *>inputObject.thisptr)

cdef class PyNextDerivedClass(PyDerivedClass):
    cdef NextDerivedClass *nextDerivedptr
    def __cinit__(self):
        self.nextDerivedptr = self.thisptr = new NextDerivedClass()
    def __dealloc__(self):
        del self.nextDerivedptr
    def SetObject(self, PyBaseClass input):
        self.nextDerivedptr.SetObject(<BaseClass *>input.thisptr)

I want to be able to call SetObject from Python as shown below:

main.py:

from inheritTest import PyBaseClass as base
from inheritTest import PyDerivedClass as der
from inheritTest import PyNextDerivedClass as nextDer

#This works now!
a = der()
b = der()
a.SetObject(b)

# This doesn't work. Keeping the SetObject declaration in NextDerivedClass causes an
# "ambiguous overloaded method" error; removing it lets the call below run, but it
# then calls the inherited implementation (from DerivedClass).
c = nextDer()
c.SetObject(b)

I thought it would work since the classes inherit from each other, but it gives me the following error:

Argument has incorrect type: expected PyBaseClass, got PyDerivedClass

Not specifying a type in the function definition makes Cython treat inputObject as a plain Python object (one with no C-level attributes, although it does have them), in which case the error is:

Cannot convert Python object to 'BaseClass *'

A somewhat hacky workaround is to have Python functions with different names that expect different argument types (e.g. SetObjectWithBase, SetObjectWithDerived) and, inside each one, call the same C++ function after casting the input. I know this works, but I would like to avoid it if possible. Catching the TypeError inside the function and handling it there might also work, but I am not sure how to implement that.
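For what it's worth, a typed argument such as `PyBaseClass inputObject` in a Cython `def` method accepts instances of subclasses as well, so separately named methods should not be needed once the wrapper classes really inherit from one another. The pure-Python sketch below only models that check with `isinstance`; the `_native_name` attribute and the returned string are hypothetical stand-ins for the wrapped C++ pointer and call, not part of the code above.

```python
class PyBaseClass:
    # Hypothetical stand-in for the wrapped C++ object (thisptr in the .pyx).
    _native_name = "BaseClass"


class PyDerivedClass(PyBaseClass):
    _native_name = "DerivedClass"

    def SetObject(self, input_object):
        # Models the typed-argument check: any PyBaseClass subclass passes.
        if not isinstance(input_object, PyBaseClass):
            raise TypeError(
                "Argument has incorrect type: expected PyBaseClass, got "
                + type(input_object).__name__
            )
        # Stand-in for self.derivedptr.SetObject(input_object.thisptr).
        return "set object from " + input_object._native_name


class PyNextDerivedClass(PyDerivedClass):
    _native_name = "NextDerivedClass"


a = PyDerivedClass()
print(a.SetObject(PyBaseClass()))         # base instance is accepted
print(a.SetObject(PyNextDerivedClass()))  # so is a twice-derived instance
```

Only an object outside the hierarchy (for example a plain `object()`) is rejected with a `TypeError` in this model.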

Hope this question makes sense, let me know if you require additional information.

EDIT: The code has been edited so that basic inheritance works. After playing around with it a bit more, I realized that the problem occurs with multiple levels of inheritance; see the edited code above. Keeping the SetObject declaration for NextDerivedClass causes an "Ambiguous Overloaded Method" error, while removing it lets me call the function on the object, but it then calls the inherited implementation (from DerivedClass).

Asked Feb 17 '15 23:02 by jeet.m


2 Answers

After a lot of help from the answers below and some experimentation, I think I understand how basic inheritance works in Cython. I'm answering my own question to validate and improve my understanding, and hopefully to help anyone who runs into a related issue in the future. If anything in this explanation is wrong, feel free to correct me in the comments below and I will edit it. I don't think this is the only way to do it, and I'm sure alternative methods work, but this is what worked for me.

Overview/Things Learnt:

From my understanding, given the appropriate declarations, the wrapper simply calls the method through the stored C++ pointer, and C++ virtual dispatch then runs the implementation that matches the actual type of the object you are calling it on.

The important thing is to mirror, in your .pyx file, the C++ inheritance structure you are trying to wrap. That means ensuring the following:

1) The imported cppclasses (the ones declared inside cdef extern from blocks) inherit from each other the same way the actual C++ classes do.

2) Each imported class declares only the methods/member variables it introduces. A virtual function that is implemented differently in BaseClass and DerivedClass should not be declared in both; as long as one inherits from the other, the declaration only needs to be in the base imported class.

3) The Python wrapper classes (i.e. PyBaseClass / PyDerivedClass) also inherit from each other the same way the actual C++ classes do.

4) Similarly, the Python interface to a virtual function only needs to exist in the base wrapper class. Do not repeat it in the derived wrappers; the correct implementation is called when you actually run the code.

5) Every Python wrapper class that is subclassed or inherited from needs an if type(self) is ClassName: check in both its __cinit__() and __dealloc__() functions. Cython runs the __cinit__() and __dealloc__() of every class in the hierarchy, so this check prevents seg-faults and similar problems. You don't need the check for "leaf nodes" of the hierarchy tree (classes that won't be inherited from or subclassed).

6) In the __dealloc__() function, delete only the current class's pointer (and not any inherited ones).

7) In the __cinit__() of an inherited class, set the current pointer as well as all inherited (base-class) pointers to an object of the type you are creating (i.e. self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass()).
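Points 5 to 7 can be illustrated with a small pure-Python simulation. This is only a sketch under stated assumptions: `fake_new`, `fake_delete`, the `live` set and the `close()` methods are hypothetical stand-ins for C++ `new`, `del` and `__dealloc__`, and the explicit `super()` calls imitate the fact that Cython runs every class's `__cinit__()` and `__dealloc__()` automatically.

```python
live = set()   # handles of simulated C++ objects that are currently allocated
_next_id = 0


def fake_new(kind):
    """Simulate `new Kind()`: record a handle and return it."""
    global _next_id
    _next_id += 1
    handle = (kind, _next_id)
    live.add(handle)
    return handle


def fake_delete(handle):
    """Simulate `del ptr`: deleting the same handle twice raises KeyError."""
    live.remove(handle)


class PyBaseClass:
    def __init__(self):                       # plays the role of __cinit__
        self.thisptr = None
        if type(self) is PyBaseClass:         # point 5: guard in a non-leaf class
            self.thisptr = fake_new("BaseClass")

    def close(self):                          # plays the role of __dealloc__
        if type(self) is PyBaseClass:
            fake_delete(self.thisptr)


class PyDerivedClass(PyBaseClass):
    def __init__(self):
        super().__init__()
        self.derivedptr = None
        if type(self) is PyDerivedClass:
            self.derivedptr = self.thisptr = fake_new("DerivedClass")

    def close(self):
        if type(self) is PyDerivedClass:
            fake_delete(self.derivedptr)      # point 6: free only this level's pointer
        super().close()


class PyNextDerivedClass(PyDerivedClass):     # leaf class: no guard needed
    def __init__(self):
        super().__init__()
        # point 7: one allocation shared by all inherited pointers
        self.nextDerivedptr = self.derivedptr = self.thisptr = fake_new("NextDerivedClass")

    def close(self):
        fake_delete(self.nextDerivedptr)
        super().close()


objs = [PyBaseClass(), PyDerivedClass(), PyNextDerivedClass()]
print(len(live))   # 3: exactly one allocation per wrapper
for obj in objs:
    obj.close()
print(len(live))   # 0: nothing leaked, nothing freed twice
```

Removing any of the guards in this model makes a wrapper either allocate more than once or hit the `KeyError` that stands in for a double delete.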

Hopefully the above points make sense alongside the code below, which compiles and runs as I intend.

BaseClass.h:

#ifndef __BaseClass__
#define __BaseClass__

#include <stdio.h>
#include <stdlib.h>
#include <string>

using namespace std;

class BaseClass
{
    public:
        BaseClass(){};
        virtual ~BaseClass(){};
        virtual void SetName(string name){printf("BASE: in set name\n");}
        virtual float Evaluate(float time){printf("BASE: in Evaluate\n");return 0;}
        virtual bool DataExists(){printf("BASE: in data exists\n");return false;}
};
#endif /* defined(__BaseClass__) */ 

DerivedClass.h:

#ifndef __DerivedClass__
#define __DerivedClass__

#include "BaseClass.h"
#include <string>

using namespace std;

class DerivedClass:public BaseClass
{
    public:
        DerivedClass(){};
        virtual ~DerivedClass(){};
        virtual void SetName(string name){printf("DERIVED CLASS: in Set name \n");}
        virtual float Evaluate(float time){printf("DERIVED CLASS: in Evaluate\n");return 1.0;}
        virtual bool DataExists(){printf("DERIVED CLASS:in data exists\n");return true;}
        virtual void MyFunction(){printf("DERIVED CLASS: in my function\n");}
        virtual void SetObject(BaseClass *input){printf("DERIVED CLASS: in set object\n");}
};
#endif /* defined(__DerivedClass__) */

NextDerivedClass.h:

#ifndef __NextDerivedClass__
#define __NextDerivedClass__

#include "DerivedClass.h"

class NextDerivedClass:public DerivedClass
{
    public:
        NextDerivedClass(){};
        virtual ~NextDerivedClass(){};
        virtual void SetObject(BaseClass *input){printf("NEXT DERIVED CLASS: in set object\n");}
        virtual bool DataExists(){printf("NEXT DERIVED CLASS: in data exists \n");return true;}
};
#endif /* defined(__NextDerivedClass__) */

inheritTest.pyx:

#Necessary Compilation Options
#distutils: language = c++
#distutils: extra_compile_args = ["-std=c++11", "-g"]

#Import necessary modules
from libcpp cimport bool
from libcpp.string cimport string
from libcpp.map cimport map
from libcpp.pair cimport pair
from libcpp.vector cimport vector

cdef extern from "BaseClass.h":
    cdef cppclass BaseClass:
        BaseClass() except +
        void SetName(string)
        float Evaluate(float)
        bool DataExists()

cdef extern from "DerivedClass.h":
    cdef cppclass DerivedClass(BaseClass):
        DerivedClass() except +
        void MyFunction()
        void SetObject(BaseClass *)

cdef extern from "NextDerivedClass.h":
    cdef cppclass NextDerivedClass(DerivedClass):
        NextDerivedClass() except +

cdef class PyBaseClass:
    cdef BaseClass *thisptr
    def __cinit__(self):
        if type(self) is PyBaseClass:
            self.thisptr = new BaseClass()
    def __dealloc__(self):
        if type(self) is PyBaseClass:
            del self.thisptr
    def SetName(self, name):
        self.thisptr.SetName(name)
    def Evaluate(self, time):
        return self.thisptr.Evaluate(time)
    def DataExists(self):
        return self.thisptr.DataExists()

cdef class PyDerivedClass(PyBaseClass):
    cdef DerivedClass *derivedptr
    def __cinit__(self):
        if type(self) is PyDerivedClass:
            self.derivedptr = self.thisptr = new DerivedClass()
    def __dealloc__(self):
        if type(self) is PyDerivedClass:
            del self.derivedptr
    def SetObject(self, PyBaseClass inputObject):
        self.derivedptr.SetObject(<BaseClass *>inputObject.thisptr)
    def MyFunction(self):
        self.derivedptr.MyFunction()

cdef class PyNextDerivedClass(PyDerivedClass):
    cdef NextDerivedClass *nextDerivedptr
    def __cinit__(self):
        self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass()
    def __dealloc__(self):
        del self.nextDerivedptr

test.py:

from inheritTest import PyBaseClass as base
from inheritTest import PyDerivedClass as der
from inheritTest import PyNextDerivedClass as nextDer

a = der()
b = der()
a.SetObject(b)
c = nextDer()
a.SetObject(c)
c.DataExists()
c.SetObject(b)
c.Evaluate(0.3)


baseSig = base()
signal = der()
baseSig.SetName(b'test')        # bytes: a std::string argument does not accept str on Python 3
signal.SetName(b'testingone')
baseSig.Evaluate(0.3)
signal.Evaluate(0.5)
signal.SetObject(b)
baseSig.DataExists()
signal.DataExists()

Notice that when I call:

c = nextDer()
c.Evaluate(0.3)

This works because Evaluate is virtual: the wrapper calls it through the stored pointer, and the call resolves at run time to the "latest" implementation in the inheritance tree. If Evaluate were implemented in NextDerivedClass.h, that version would be called (I have tried that and it works). Since it is not there, the lookup goes one step up to DerivedClass. The function is implemented there, so the output is:

>> DERIVED CLASS: in Evaluate
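The same lookup order can be imitated in pure Python, whose method resolution also picks the most-derived override. The classes below are stand-ins for the C++ headers (they return the message instead of printing it) and are meant only as an analogy, not as the wrapped code:

```python
class BaseClass:
    def Evaluate(self, time):
        return "BASE: in Evaluate"

    def DataExists(self):
        return "BASE: in data exists"


class DerivedClass(BaseClass):
    def Evaluate(self, time):
        return "DERIVED CLASS: in Evaluate"

    def DataExists(self):
        return "DERIVED CLASS: in data exists"


class NextDerivedClass(DerivedClass):
    # No Evaluate here, so the lookup falls back to DerivedClass.
    def DataExists(self):
        return "NEXT DERIVED CLASS: in data exists"


c = NextDerivedClass()
print(c.Evaluate(0.3))   # DERIVED CLASS: in Evaluate
print(c.DataExists())    # NEXT DERIVED CLASS: in data exists
```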

I hope this helps someone in the future. If there are errors in my understanding, or in the grammar or syntax, feel free to comment below and I will try to address them. Again, big thanks to those who answered below; this is essentially a summary of their answers that helped me validate my understanding. Thanks!

Answered Nov 05 '22 11:11 by jeet.m


Your code, as written, doesn't compile. I suspect that your real PyDerivedClass doesn't actually derive from PyBaseClass; if it did, that last line would have to be

(<DerivedClass*>self.thisptr).SetObject(inputObject.thisptr)

This would also explain the type error you're getting, which is a bug I can't reproduce.

Answered Nov 05 '22 10:11 by robertwb