
Overloading the [] operator in python class to refer to a numpy.array data member

I have written a data container class that wraps a numpy ndarray member and adds methods to generate time-series and cross-sectional masks, fetch the date index (row number) in ring-buffer mode, handle resizing (keeping in mind that the data may be a ring buffer), and enforce restrictions on the shape and dimensions.

Because of how my class is implemented, I now have to access the wrapped data by explicitly referring to its .data member. This is cumbersome, and I'd like to implement the [] operator in my class so that, when it is used on an instance of my class, it performs the same operation on the underlying ndarray. How can I achieve this?

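import numpy

class MyArray(object):
    def __init__(self, shape, fill_value, dtype):
        self.shape = shape
        self.fill_value = fill_value
        self.dtype = dtype
        # numpy.empty() has no fill_value argument; numpy.full() allocates and fills
        self.data = numpy.full(shape, fill_value, dtype=dtype)

    def reset(self, fill_value=None):
        # compare with None so that a fill value of 0 is not silently replaced
        self.data.fill(self.fill_value if fill_value is None else fill_value)

    def resize(self, shape):
        if self.data.ndim != len(shape): raise Exception("dim error")
        # reject any axis that would shrink (tuple "<" compares lexicographically)
        if any(new < old for new, old in zip(shape, self.data.shape)):
            raise Exception("sizing down not permitted")
        # do resizing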

Now, whenever I use this container elsewhere, I have to write code like this:

arr = MyArray(shape=(10000,20), fill_value=numpy.nan, dtype='float')
arr.data[::10] = numpy.nan
msk = numpy.random.randn(10000,20)<.5
arr.data[~msk] = -1.

Having to refer to arr.data explicitly every time is cumbersome and error-prone (I keep forgetting the .data suffix in many places).

Is there any way I can add a few operators such that slicing and indexing on arr actually operates on arr.data implicitly?

asked Nov 23 '15 22:11 by Mindstorm


1 Answer

You need to implement the __getitem__ and __setitem__ magic methods.

A complete overview of the magic methods can be found here.
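To see what actually gets forwarded, here is a tiny hypothetical class (KeyRecorder, not part of the answer) whose __getitem__ simply returns its argument. Python evaluates obj[...] as type(obj).__getitem__(obj, key), where key is an int, a slice object, or a tuple when the indices are comma-separated; a boolean mask arrives as the mask itself. Because self.data[key] passes the key along unchanged, every indexing form NumPy understands also works on the wrapper.

```python
class KeyRecorder:
    """Hypothetical helper: returns whatever key Python passes to __getitem__."""

    def __getitem__(self, key):
        return key


k = KeyRecorder()
print(k[9])          # 9
print(k[1:5])        # slice(1, 5, None)
print(k[::10])       # slice(None, None, 10)
print(k[::10, 0])    # (slice(None, None, 10), 0)
```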

import numpy as np

class MyArray():
    def __init__(self):
        self.data = np.zeros(10)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __repr__(self):
        return 'MyArray({})'.format(self.data)


a = MyArray()

print(a[9])
print(a[1:5])
a[:] = np.arange(10)
print(a)

Which will give you this result:

0.0
[ 0.  0.  0.  0.]
MyArray([ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9.])
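Applied to the container from the question, the same two methods let arr[::10] = ... and arr[~msk] = -1. work without .data. The sketch below also adds two optional extras that go beyond the answer and are assumptions on my part: __array__, NumPy's conversion hook, so that np.asarray(arr) or a ufunc such as np.isnan(arr) sees the wrapped array, and a __getattr__ fallback that forwards ordinary attribute lookups such as arr.shape or arr.sum() to self.data. The question's shape and dtype attributes are dropped here because __getattr__ reads them from the array itself, so they stay correct after a resize.

```python
import numpy as np


class MyArray(object):
    """Sketch: the question's container with indexing forwarded to self.data."""

    def __init__(self, shape, fill_value, dtype):
        self.fill_value = fill_value
        self.data = np.full(shape, fill_value, dtype=dtype)

    def reset(self, fill_value=None):
        self.data.fill(self.fill_value if fill_value is None else fill_value)

    def __getitem__(self, key):
        # key can be an int, slice, tuple or boolean mask; NumPy interprets it
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __array__(self, dtype=None, copy=None):
        # conversion hook used by np.asarray(arr), np.isnan(arr), etc.
        arr = np.asarray(self.data, dtype=dtype)
        return arr.copy() if copy else arr

    def __getattr__(self, name):
        # only reached when normal lookup fails: forward public attributes
        # such as .shape or .sum() to the ndarray, but never 'data' or dunders
        if name == 'data' or name.startswith('__'):
            raise AttributeError(name)
        return getattr(self.data, name)


arr = MyArray(shape=(10000, 20), fill_value=np.nan, dtype='float')
arr[::10] = 0.0                        # every 10th row
msk = np.random.randn(10000, 20) < .5
arr[~msk] = -1.                        # boolean-mask assignment
print(arr.shape)                       # (10000, 20), via __getattr__
print(np.isnan(arr).any())             # ufunc sees the data via __array__
```

Note that arithmetic operators such as + are looked up on the class, not the instance, so __getattr__ does not forward them; arr + 1 would still need its own __add__.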

Inheritance

If you only want to modify or extend the behaviour of np.ndarray, you could inherit from it. This is a little more complicated than subclassing a normal Python class, but your case should not be hard to implement:

import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, shape, fill_value=0, dtype=float):
        data = np.full(shape, fill_value, dtype)
        obj = np.asarray(data).view(cls)
        obj.fill_value = fill_value
        return obj

    def reset(self, fill_value=None):
        if fill_value is not None:
            self.fill_value = fill_value

        self.fill(self.fill_value)

For more info, see here.
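One part of that extra complexity: NumPy creates new instances of the subclass without calling __new__ whenever you take a view or slice (for example arr[::10]), so such a view has no fill_value, and calling reset() on it without an argument raises AttributeError. NumPy's subclassing guide handles this with __array_finalize__, which runs for every new instance. The sketch below is the subclass from the answer with that hook added; the fallback value 0 for a plain-ndarray parent is my assumption, chosen to match the default in __new__.

```python
import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, shape, fill_value=0, dtype=float):
        data = np.full(shape, fill_value, dtype)
        obj = np.asarray(data).view(cls)
        obj.fill_value = fill_value
        return obj

    def __array_finalize__(self, obj):
        # Called for every new MyArray, including views and slices such as
        # arr[::10]; obj is the array it was derived from (None only for an
        # explicit ndarray.__new__ call). Copy fill_value so it survives.
        if obj is None:
            return
        self.fill_value = getattr(obj, 'fill_value', 0)

    def reset(self, fill_value=None):
        if fill_value is not None:
            self.fill_value = fill_value

        self.fill(self.fill_value)


arr = MyArray((10000, 20), fill_value=np.nan)
arr[~(np.random.randn(10000, 20) < .5)] = -1.
rows = arr[::10]      # a view: still a MyArray, and it inherits fill_value
rows.reset(0.0)       # fills the view, i.e. every 10th row of arr
```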

answered Oct 11 '22 19:10 by MaxNoe