
Subclass of numpy ndarray doesn't work as expected

Tags: python, numpy

Hello, everyone.

I found a strange behavior when subclassing an ndarray.

import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        return obj

    def __init__(self, *args, **kwargs):
        return

    def __array_finalize__(self, obj):
        return

a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)

a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)

print(a_sum.ndim)  # 1
print(b_sum.ndim)  # 2

As you can see, the keepdims argument doesn't work for my subclass fooarray: the result lost one of its axes. How can I avoid this problem? Or, more generally, how can I subclass numpy ndarray correctly?

Ken T asked May 19 '14 10:05


1 Answer

np.sum can accept a variety of objects as input: not only ndarrays, but also lists, generators, and np.matrix instances, for example. The keepdims parameter obviously does not make sense for lists or generators. It is not appropriate for np.matrix instances either, since an np.matrix always has 2 dimensions. If you look at the call signature of np.matrix.sum, you see that its sum method has no keepdims parameter:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)

So some subclasses of ndarray may have sum methods which do not have a keepdims parameter. This is an unfortunate violation of the Liskov substitution principle and the origin of the pitfall you encountered.
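The signature quoted above can be checked directly. This is a small sketch that assumes np.matrix.sum is a plain Python method (so inspect.signature can read it); the variable names are only illustrative.

```python
import inspect

import numpy as np

# np.matrix.sum is defined in Python, so its parameters can be introspected.
matrix_sum_params = list(inspect.signature(np.matrix.sum).parameters)
print(matrix_sum_params)

# True would mean np.matrix.sum accepts keepdims; the answer says it does not.
has_keepdims = "keepdims" in matrix_sum_params
print(has_keepdims)
```

If has_keepdims prints False, calling np.matrix.sum with keepdims=True would fail with a TypeError, which is the exception the delegating function tries to avoid.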

Now if you look at the source code for np.sum, you see that it is a delegating function which tries to determine what to do based on the type of the first argument.

If the type of the first argument is not exactly ndarray (a subclass such as fooarray does not count), it drops the keepdims parameter. It does this because passing the keepdims parameter to np.matrix.sum would raise an exception.

So because np.sum is trying to do the delegation in the most general way, not making any assumption about what arguments a subclass of ndarray may take, it drops the keepdims parameter when passed a fooarray.
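The delegation described above can be imitated in a few lines. This is only a simplified sketch, not NumPy's actual source: delegating_sum is a hypothetical stand-in, and other NumPy versions may forward keepdims differently.

```python
import numpy as np


def delegating_sum(a, axis=None, keepdims=False):
    """Hypothetical, simplified stand-in for the delegation described above."""
    if type(a) is not np.ndarray and hasattr(a, "sum"):
        # Subclass (or other object) with its own sum(): call it WITHOUT
        # keepdims, because that method may not accept the argument.
        return a.sum(axis=axis)
    # Exact ndarray, or an array-like without a sum method: keepdims is honoured.
    return np.asarray(a).sum(axis=axis, keepdims=keepdims)


class fooarray(np.ndarray):
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)


a = fooarray(np.ones((3, 5)))
b = np.ones((3, 5))

sub_ndim = delegating_sum(a, axis=0, keepdims=True).ndim
base_ndim = delegating_sum(b, axis=0, keepdims=True).ndim
print(sub_ndim)   # 1: keepdims was silently dropped for the subclass
print(base_ndim)  # 2: keepdims was honoured for the plain ndarray
```

The exact type check is what matters here: fooarray is a subclass of ndarray, yet it takes the first branch and loses keepdims.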

The workaround is to not use np.sum, but call a.sum instead. This is more direct anyway, since np.sum is merely a delegating function.

import numpy as np


class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array, *args, **kwargs).view(cls)
        return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 2
print(b_sum.ndim)  # 2
unutbu answered Oct 18 '22 14:10
