Custom class with __add__ to add with NumPy array

I have a custom class implementing __add__ and __radd__ as

import numpy

class Foo(object):

    def __init__(self, val):
        self.val = val

    def __add__(self, other):
        print('__add__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return self.val + other

    def __radd__(self, other):
        print('__radd__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return other + self.val

I first test __add__

r1 = Foo(numpy.arange(3)) + numpy.arange(3,6)
print('type results = %s' % type(r1))
print('result = {}'.format(r1))

and it leads to the expected result

>>> __add__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'numpy.ndarray'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [3  5  7]

However, testing __radd__

r2 = numpy.arange(3) + Foo(numpy.arange(3,6))
print('type results = %s' % type(r2))
print('result = {}'.format(r2))

I get

>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [array([3, 4, 5]) array([4, 5, 6]) array([5, 6, 7])]

This doesn't make any sense to me. Does NumPy overload __add__ for arbitrary objects, which then takes priority over my __radd__? If so, why would they do such a thing? Additionally, how can I avoid this? I really want to be able to add my custom class to a NumPy array that sits on the left. Thanks.

Max asked Aug 29 '17 19:08


1 Answer

This is hidden by the comments, but should be the answer.

By default, NumPy operations accept arbitrary objects as operands and attempt to apply the operation element by element (according to the broadcasting rules), treating an unknown object as a scalar.

What this means, for example, is that given

import numpy as np

class N:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        return self.x + other

    def __radd__(self, other):
        return other + self.x

due to Python's operator resolution,

N(3) + np.array([1, 2, 3])

will call the __add__ above once, with N(3) as self and the entire array as other, and then perform a regular NumPy addition.
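A minimal sketch of that left-hand case, assuming a hypothetical CountingN class (not part of the original question) that records the type of every `other` it receives:

```python
import numpy as np

class CountingN:
    """Hypothetical variant of N that records each __add__ call."""

    def __init__(self, x):
        self.x = x
        self.add_calls = []  # type of `other` for every __add__ call

    def __add__(self, other):
        self.add_calls.append(type(other))
        return self.x + other

n = CountingN(3)
result = n + np.array([1, 2, 3])

print(len(n.add_calls))  # number of times __add__ ran
print(n.add_calls[0])    # type of the operand it received
print(result)            # sum computed by ordinary NumPy addition
```

Python tries the left operand first, so NumPy only gets involved inside the method body, when self.x is added to the whole array.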

On the other hand

np.array([1, 2, 3]) + N(3)

will successfully enter NumPy's ufunc machinery (numpy.add in this case), since ndarray.__add__ accepts arbitrary objects as the "other" operand, and then attempt to perform successively:

1 + N(3)
2 + N(3)
3 + N(3)

This means that __radd__ above will be called 3 times instead of once, once for each element, which slows the operation down significantly and produces an object array. To disable this behavior and make NumPy's operator return NotImplemented when it meets an N object, so that Python falls back to the right-hand operand's __radd__, add the following to the body of your class:

class N:
    ...
    __numpy_ufunc__ = None # NumPy before 1.13
    __array_ufunc__ = None # NumPy 1.13 and above

If backward compatibility is not an issue, only the second is needed.
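As a rough check of the difference, here is a sketch assuming NumPy 1.13 or later. The class names Plain and OptOut and the call counter are illustrative additions, not part of the original code:

```python
import numpy as np

class Plain:
    """No opt-out: NumPy treats an instance as a scalar object."""

    def __init__(self, x):
        self.x = x
        self.radd_calls = 0  # how many times __radd__ was invoked

    def __radd__(self, other):
        self.radd_calls += 1
        return other + self.x

class OptOut(Plain):
    """Same behavior, but opted out of NumPy's ufunc handling."""
    __array_ufunc__ = None

arr = np.array([1, 2, 3])

plain = Plain(3)
r_plain = arr + plain   # NumPy loops over the elements itself

opt = OptOut(3)
r_opt = arr + opt       # ndarray defers, so __radd__ sees the whole array

print(plain.radd_calls, r_plain.dtype)
print(opt.radd_calls, r_opt.dtype)
```

Without the opt-out, the result is an object array built from one __radd__ call per element. With it, __radd__ runs once and returns a regular numeric array.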

kabanus answered Sep 23 '22 13:09