When subclassing a numpy ndarray, how can I modify __getitem__ properly?

I am attempting to subclass numpy's ndarray. In my subclass, called MyClass, I've added a field called time as a parallel array to the main data.

My goal is the following: suppose I make an instance of MyClass, let's call it mc. I slice mc, for instance mc[2:6], and I want the resulting object to contain not only the properly sliced np array, but also the correspondingly sliced time array.

Here is my attempt:

import numpy as np

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj
    def __array_finalize__(self, obj):
        # obj may be a plain ndarray without `time` (e.g. the view() in __new__)
        self.time = getattr(obj, 'time', None)
    def __getitem__(self, item):
        # print(item)  # for testing
        ret = super(MyClass, self).__getitem__(item)
        ret.time = self.time.__getitem__(item)
        return ret

This does not work. After many hours of messing around, I realized it is because when I call mc[2:6], __getitem__ is actually called multiple times. First when it is called, the item variable, as expected, is slice(2,6,None). But then, the line containing super(MyClass, self)... calls this same function again, presumably to go retrieve the individual elements of the slice.

The issue is that it supplies __getitem__ with a strange set of parameters, always negative numbers. In the example of mc[2:6], it calls the method 4 more times, with item values of -4, -3, -2, and -1.

As you can see, this makes it impossible for me to properly adjust the ret.time variable, since it attempts to modify it multiple times, often with nonsensical indices.
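One plausible explanation for the negative indices: they come from NumPy's array printing (repr/str, which the interactive prompt calls to echo mc[2:6]), which reads the elements of the result one at a time. ndarray.__getitem__ itself does not call back into the subclass's __getitem__. The sketch below uses a hypothetical LoggingArray subclass that only records what __getitem__ receives, so the slice and the printing can be observed separately. In the NumPy versions we are aware of, printing passes negative positions (plain ints in older releases, one-element tuples such as (-4,) in newer ones); treat the exact form as version-dependent.

```python
import numpy as np


class LoggingArray(np.ndarray):
    """Hypothetical subclass that only records what __getitem__ receives."""

    calls = []  # class-level log shared by all instances, for this demo only

    def __getitem__(self, item):
        LoggingArray.calls.append(item)
        return super().__getitem__(item)


a = np.arange(10).view(LoggingArray)

s = a[2:6]                      # the slice itself
slice_calls = list(LoggingArray.calls)

LoggingArray.calls.clear()
text = repr(s)                  # what the interactive prompt does to display s
repr_calls = list(LoggingArray.calls)

print("while slicing: ", slice_calls)
print("while printing:", repr_calls)
```

Each of the printing lookups selects a single element, so super().__getitem__ returns a NumPy scalar rather than a MyClass instance, and assigning ret.time on a scalar fails.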

I have tried working around this in many ways, including copying the object and editing that copy instead, taking various views of the object, and many other hacks, but none seem to overcome this issue that __getitem__ is repeatedly called with negative indices that do not line up with the requested slice.

Any help or explanations as to what is going on would be greatly appreciated.

asked Jul 08 '15 at 02:07 by benson


1 Answer

I had a similar problem and solved it by following the numpy matrix class as an example. As you noticed, __getitem__ can be called several times, and the new array only receives its attributes later, in __array_finalize__. So the solution is to store the pending index in __getitem__ and apply it to time in __array_finalize__.

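import numpy as np

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj

    def __array_finalize__(self, obj):
        # obj is a plain ndarray (no `time`) when called from the view() in
        # __new__, so fall back to None instead of raising AttributeError.
        self.time = getattr(obj, 'time', None)
        # If the parent stashed an index in __getitem__, apply it to `time`.
        index = getattr(obj, '_new_time_index', None)
        if index is not None and self.time is not None:
            self.time = self.time[index]

    def __getitem__(self, item):
        # `time` runs parallel to axis 0, so only the first entry of a
        # (non-empty) tuple index applies to it.
        self._new_time_index = item[0] if isinstance(item, tuple) and item else item
        try:
            return super().__getitem__(item)
        finally:
            # Clear the stash so later views (e.g. mc + 1) are not re-indexed.
            self._new_time_index = None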
answered Nov 05 '22 at 20:11 by greedybuddha
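A different technique from the stash-in-__getitem__ pattern is also possible: call super().__getitem__ first, pass scalar results (such as the element lookups done while printing) through unchanged, and index time directly when the result is still an instance of the subclass. The minimal sketch below assumes a hypothetical class named TimedArray whose time array runs along axis 0; index forms such as np.newaxis are not handled.

```python
import numpy as np


class TimedArray(np.ndarray):
    """Hypothetical ndarray subclass with a `time` array parallel to axis 0."""

    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = None if time is None else np.asarray(time)
        return obj

    def __array_finalize__(self, obj):
        # obj is a plain ndarray (no `time`) when called from __new__'s view(),
        # so fall back to None instead of raising AttributeError.
        self.time = getattr(obj, 'time', None)

    def __getitem__(self, item):
        result = super().__getitem__(item)
        # Scalars (e.g. the element lookups done while printing) and arrays
        # without time information are returned unchanged.
        if not isinstance(result, TimedArray) or self.time is None:
            return result
        # Only the axis-0 part of the index applies to `time`.
        if isinstance(item, tuple):
            axis0 = item[0] if item else slice(None)
        else:
            axis0 = item
        if isinstance(axis0, (int, np.integer)):
            # Axis 0 was collapsed (e.g. m[1] on 2-D data): no per-row time left.
            result.time = None
        else:
            result.time = self.time[axis0]
        return result


a = TimedArray([10, 20, 30, 40, 50, 60, 70], time=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
s = a[2:6]
print(s, s.time)  # [30 40 50 60] [2. 3. 4. 5.]
```

Because the time index is computed from the same item that was just applied to the data, nothing needs to be carried over between calls, and repeated lookups during printing cannot leave stale state behind.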