 

Python: How to extend a huge class with minimum lines of code?

Original problem description

The problem arose while I was implementing a machine learning algorithm with numpy. I want a new class ludmo that works the same as numpy.ndarray but has a few more properties, for example a new property ludmo.foo. I've tried the methods below, but none is satisfactory.

1. Wrapper

First I created a wrapper class for numpy.ndarray, as

import numpy as np

class ludmo(object):
    def __init__(self):
        self.foo = None
        self.data = np.array([])

But when I use a function (in scikit-learn, which I cannot modify) that manipulates a list of np.ndarray instances, I first have to extract the data field of each ludmo object and collect them into a list. The function then sorts that list, and I lose the correspondence between the data and the original ludmo objects.

2. Inheritance

Then I tried to make ludmo a subclass of numpy.ndarray, as

import numpy as np

class ludmo(np.ndarray):
    def __init__(self, shape, dtype=float, buffer=None, offset=0, strides=None, order=None):
        super().__init__(shape, dtype, buffer, offset, strides, order)
        self.foo = None

But then another problem arises: the most common way to create a numpy.ndarray object is numpy.array(some_list), which returns a plain numpy.ndarray, so I have to convert it to a ludmo object. So far I have found no good way to do this; simply changing the __class__ attribute results in an error.

I'm new to Python and numpy, so there must be some elegant way that I don't know. Any advice is appreciated.

It would be even better if someone could give a generic solution, one that applies not only to numpy.ndarray but to any class.

xyguo asked Mar 02 '15 07:03




2 Answers

As explained in the docs, you can add your own methods to np.ndarray by subclassing it:

import numpy as np

class Ludmo(np.ndarray): 
    def sumcols(self):
        return self.sum(axis=1)

    def sumrows(self):
        return self.sum(axis=0)

    def randomize(self):
        self[:] = np.random.rand(*self.shape)

and then create instances using the np.ndarray.view() method:

a = np.random.rand(4,5).view(Ludmo)

To define new attributes, add an __array_finalize__() method to the class:

def __array_finalize__(self, arr):
    self.foo = 'foo'
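Putting the pieces together, a minimal sketch might look like the following. It assumes you want foo to default to None and to be carried over to slices and views; the getattr(obj, "foo", None) default is an illustrative choice, not part of the answer above.

```python
import numpy as np


class Ludmo(np.ndarray):
    """ndarray subclass carrying one extra attribute, foo."""

    def __array_finalize__(self, obj):
        # Called for every new Ludmo: explicit construction (obj is None),
        # view casting, and slicing. Copy foo from the source array when it
        # has one, otherwise fall back to None (an assumed default).
        self.foo = getattr(obj, "foo", None)


# View casting reinterprets an ordinary array as a Ludmo without copying data.
a = np.array([[1.0, 2.0], [3.0, 4.0]]).view(Ludmo)
a.foo = "label"

# Slicing builds a new Ludmo and calls __array_finalize__ with obj=a.
row = a[0]
```

Because a is still an ndarray, it can be passed directly to functions that expect one, which avoids the bookkeeping problem of the plain wrapper in the question.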
Saullo G. P. Castro answered Sep 27 '22 21:09


Since you ask about a generic solution, here's a generic wrapper class that you can use (from http://code.activestate.com/recipes/577555-object-wrapper-class/):

class Wrapper(object):
    '''
    Object wrapper class.
    This is a wrapper for objects. It is initialised with the object to wrap
    and then proxies the unhandled getattribute methods to it.
    Other classes are to inherit from it.
    '''
    def __init__(self, obj):
        '''
        Wrapper constructor.
        @param obj: object to wrap
        '''
        # wrap the object
        self._wrapped_obj = obj

    def __getattr__(self, attr):
        # see if this object has attr
        # NOTE do not use hasattr, it goes into
        # infinite recursion
        if attr in self.__dict__:
            # this object has it
            return getattr(self, attr)
        # proxy to the wrapped object
        return getattr(self._wrapped_obj, attr)

The way this works is:

When, for example, scikit-learn accesses ludmo.data, Python first calls ludmo.__getattribute__('data'). If ludmo doesn't have a 'data' attribute, Python then calls ludmo.__getattr__('data').

By overriding __getattr__ you intercept this call, check whether your ludmo has the attribute itself (again, you could get into infinite recursion otherwise), and otherwise forward the lookup to your internal object. This covers every ordinary attribute access on your internal numpy object.

Update: you would also have to implement __setattr__ the same way. Otherwise assignments are stored on the wrapper instead of the wrapped object, and you get this:

>>> class bla(object):
...  def __init__(self):
...   self.a = 1
...  def foo(self):
...   print(self.a)
...
>>> d = Wrapper(bla())
>>> d.a
1
>>> d.foo()
1
>>> d.a = 2
>>> d.a
2
>>> d.foo()
1
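One way to forward writes as well is sketched below. ForwardingWrapper is a hypothetical name, and the policy it uses (keep a name on the wrapper only if the wrapper already holds it, send everything else to the wrapped object) is an assumption; other policies are possible.

```python
class ForwardingWrapper(object):
    """Hypothetical wrapper that forwards attribute reads and writes."""

    def __init__(self, obj):
        # Bypass our own __setattr__ so the reference is stored on the wrapper.
        object.__setattr__(self, "_wrapped_obj", obj)

    def __getattr__(self, attr):
        # Only called when normal lookup on the wrapper fails.
        # object.__getattribute__ avoids recursing back into __getattr__.
        return getattr(object.__getattribute__(self, "_wrapped_obj"), attr)

    def __setattr__(self, attr, value):
        if attr in self.__dict__:
            # The wrapper already owns this name: update it in place.
            object.__setattr__(self, attr, value)
        else:
            # Everything else is written to the wrapped object.
            setattr(self._wrapped_obj, attr, value)


class Bla(object):
    def __init__(self):
        self.a = 1

    def foo(self):
        return self.a


d = ForwardingWrapper(Bla())
d.a = 2          # lands on the wrapped Bla instance, not on the wrapper
result = d.foo()
```

With this version the wrapped object's own methods see the new value, so the mismatch shown in the session above no longer occurs.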

You probably also want to set a metaclass that intercepts calls to the magic methods of new-style classes (for the full class see https://github.com/hpcugent/vsc-base/blob/master/lib/vsc/utils/wrapper.py; for background see "How can I intercept calls to python's "magic" methods in new style classes?"). However, this is only needed if you still want to access x.__name__ or x.__file__ and get the magic attribute from the wrapped class rather than from your own class.

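    # create proxies for wrapped object's double-underscore attributes
    # (Python 2 syntax: Python 3 ignores a nested __metaclass__ class;
    # this fragment relies on _obj, __wraps__ and __ignore__, which are
    # not defined in the Wrapper class above)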
    class __metaclass__(type):
        def __init__(cls, name, bases, dct):

            def make_proxy(name):
                def proxy(self, *args):
                    return getattr(self._obj, name)
                return proxy

            type.__init__(cls, name, bases, dct)
            if cls.__wraps__:
                ignore = set("__%s__" % n for n in cls.__ignore__.split())
                for name in dir(cls.__wraps__):
                    if name.startswith("__"):
                        if name not in ignore and name not in dct:
                            setattr(cls, name, property(make_proxy(name)))
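The reason magic methods need separate handling is that Python looks them up on the class, not on the instance, when they are invoked implicitly, so __getattr__ is never consulted. The sketch below illustrates this in Python 3 with len(); the minimal Wrapper and the LenWrapper subclass are illustrative names, and defining each needed method by hand is a simpler alternative to the metaclass, not the approach used in the linked code.

```python
class Wrapper(object):
    """Minimal delegating wrapper, repeated here to keep the sketch standalone."""

    def __init__(self, obj):
        self._wrapped_obj = obj

    def __getattr__(self, attr):
        return getattr(self._wrapped_obj, attr)


class LenWrapper(Wrapper):
    """Defines the special method on the class so len() can find it."""

    def __len__(self):
        return len(self._wrapped_obj)


plain = Wrapper([1, 2, 3])

# Implicit invocation looks for __len__ on type(plain) and skips __getattr__.
try:
    len(plain)
    plain_supports_len = True
except TypeError:
    plain_supports_len = False

# Explicit attribute access goes through __getattr__ and reaches the list.
explicit_call = plain.__len__()

wrapped_len = len(LenWrapper([1, 2, 3]))
```

For a handful of special methods, writing them out like this is often enough; generating them for every double-underscore name is what the metaclass above automates.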
Jens Timmerman answered Sep 27 '22 21:09