 

How can I override comparisons between NumPy's ndarray and my type?

In NumPy, it is possible to use the __array_priority__ attribute to take control of binary operators acting on an ndarray and a user-defined type. For instance:

import numpy as np

class Foo(object):
  def __radd__(self, lhs): return 0
  __array_priority__ = 100

a = np.random.random((100,100))
b = Foo()
a + b # calls b.__radd__(a) -> 0

The same thing, however, doesn't appear to work for comparison operators. For instance, if I add the following line to Foo, then it is never called from the expression a < b:

def __rlt__(self, lhs): return 0

I realize that __rlt__ is not really a Python special name, but I thought it might work. I tried all of __lt__, __le__, __eq__, __ne__, __ge__, __gt__ with and without a preceding r, plus __cmp__, too, but I could never get NumPy to call any of them.
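For context, Python itself has no `__rlt__`. When the left operand's `__lt__` returns `NotImplemented`, Python retries with the right operand's mirrored method, `__gt__`; likewise `<=` falls back to `__ge__`, while `==` and `!=` fall back to the right operand's own `__eq__` and `__ne__`. The sketch below uses plain Python classes (`Left` and `Right` are made-up names, no NumPy involved, Python 3 assumed) to show which method actually runs.

```python
class Left(object):
    # Mimics a left operand that does not know how to compare itself with
    # Right. The methods inherited from object (__le__, __gt__, __eq__, ...)
    # also return NotImplemented for an unrelated operand.
    def __lt__(self, other):
        return NotImplemented  # ask Python to try the reflected method

class Right(object):
    # Each method reports its own name, so we can see which one Python picked.
    def __lt__(self, other): return "Right.__lt__"
    def __le__(self, other): return "Right.__le__"
    def __eq__(self, other): return "Right.__eq__"
    def __ne__(self, other): return "Right.__ne__"
    def __gt__(self, other): return "Right.__gt__"
    def __ge__(self, other): return "Right.__ge__"

l, r = Left(), Right()
print(l < r)   # Right.__gt__
print(l <= r)  # Right.__ge__
print(l == r)  # Right.__eq__
print(l != r)  # Right.__ne__
print(l > r)   # Right.__lt__
print(l >= r)  # Right.__le__
```

So to intercept `a < b` from the right-hand side, the hook is `b.__gt__`, and it only runs if the array's own `__lt__` first gives up by returning `NotImplemented`.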

Can these comparisons be overridden?

UPDATE

To avoid confusion, here is a longer description of NumPy's behavior. For starters, here's what the book Guide to NumPy says:

If the ufunc has 2 inputs and 1 output and the second input is an Object array
then a special-case check is performed so that NotImplemented is returned if the
second input is not an ndarray, has the array priority attribute, and has an
r<op> special method.
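Read literally, the quoted special case is a three-part test on the second input. The sketch below is only a plain-Python restatement of the book's wording, not NumPy's actual C implementation; `FakeNdarray`, `defers_to_other`, `NoPriority` and `WithPriority` are hypothetical names introduced so it runs without NumPy.

```python
class FakeNdarray(object):
    """Hypothetical stand-in for numpy.ndarray."""

def defers_to_other(other, op, ndarray_type=FakeNdarray):
    # True when the second input is not an ndarray, has __array_priority__,
    # and defines the reflected method __r<op>__ (op="add" -> "__radd__").
    return (not isinstance(other, ndarray_type)
            and hasattr(other, "__array_priority__")
            and hasattr(other, "__r%s__" % op))

class NoPriority(object):
    def __radd__(self, lhs): return 1

class WithPriority(object):
    __array_priority__ = 100
    def __radd__(self, lhs): return 1

print(defers_to_other(NoPriority(), "add"))    # False: no __array_priority__
print(defers_to_other(WithPriority(), "add"))  # True
print(defers_to_other(FakeNdarray(), "add"))   # False: it is an "ndarray"
```

Under this reading, a class like `Bar1` below passes the test for `add` because it defines both `__array_priority__` and `__radd__`, while `Bar0` fails the priority check.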

I think this is the rule that makes + work. Here's an example:

import numpy as np
a = np.random.random((2,2))

class Bar0(object):
  def __add__(self, rhs): return 0
  def __radd__(self, rhs): return 1

b = Bar0()
print(a + b) # Calls __radd__ four times, returns an array
# [[1 1]
#  [1 1]]

class Bar1(object):
  def __add__(self, rhs): return 0
  def __radd__(self, rhs): return 1
  __array_priority__ = 100

b = Bar1()
print(a + b) # Calls __radd__ once, returns 1
# 1

As you can see, without __array_priority__, NumPy interprets the user-defined object as a scalar type, and applies the operation at every position in the array. That's not what I want. My type is array-like (but should not be derived from ndarray).
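The sketch below, assuming Python 3 and a recent NumPy rather than the 2013-era setup in the question, shows where the "scalar" interpretation comes from: `np.asarray` wraps an object NumPy does not recognize as a 0-d array of dtype `object`, so `a + b` broadcasts it over `a`, and the object-dtype addition falls back to `__radd__` for every element. `Scalarish` is a made-up class mirroring `Bar0`.

```python
import numpy as np

class Scalarish(object):
    # Hypothetical array-like type with no __array_priority__.
    def __radd__(self, lhs):
        return 1

b = Scalarish()
wrapped = np.asarray(b)  # an unknown object becomes a 0-d object array
print(wrapped.shape, wrapped.dtype)  # () object

a = np.ones((2, 2))
result = a + b  # b is broadcast; each element's addition calls b.__radd__
print(result.dtype)     # object
print(result.tolist())  # [[1, 1], [1, 1]]
```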

Here's a longer example showing how this fails when all of the comparison methods are defined:

class Foo(object):
  def __cmp__(self, rhs): return 0
  def __lt__(self, rhs): return 1
  def __le__(self, rhs): return 2
  def __eq__(self, rhs): return 3
  def __ne__(self, rhs): return 4
  def __gt__(self, rhs): return 5
  def __ge__(self, rhs): return 6
  __array_priority__ = 100

b = Foo()
print(a < b) # Calls __cmp__ four times, returns an array
# [[False False]
#  [False False]]
asked Jan 31 '13 06:01 by AndyJost



1 Answer

It looks like I can answer this myself. np.set_numeric_ops can be used as follows:

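import numpy as np

class Foo(object):
  def __lt__(self, rhs): return 0
  def __le__(self, rhs): return 1
  def __eq__(self, rhs): return 2
  def __ne__(self, rhs): return 3
  def __gt__(self, rhs): return 4
  def __ge__(self, rhs): return 5
  __array_priority__ = 100

def override(name):
  # Wrap the ufunc np.<name> so that comparing an array against a Foo
  # returns NotImplemented; Python then tries Foo's reflected method.
  def ufunc(x, y):
    if isinstance(y, Foo): return NotImplemented
    return getattr(np, name)(x, y)
  return ufunc

# Replace the functions ndarray uses for its comparison operators.
np.set_numeric_ops(
    **{
        name: override(name) for name in (
            "less", "less_equal", "equal", "not_equal",
            "greater_equal", "greater"
        )
    }
)

a = np.random.random((2,2))
b = Foo()
print(a < b)
# 4  (the array's < returned NotImplemented, so Python called b.__gt__(a))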
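A hedged alternative for later NumPy releases: `np.set_numeric_ops` has since been deprecated, and NumPy 1.13 introduced the `__array_ufunc__` protocol. Setting `__array_ufunc__ = None` on a class makes ndarray's binary operators, comparisons included, return `NotImplemented` when the other operand is an instance of that class, so Python falls back to the reflected method (`b.__gt__(a)` for `a < b`). A minimal sketch, assuming Python 3 and NumPy 1.13 or newer, reusing the answer's return values:

```python
import numpy as np

class Foo(object):
    # Opt out of ufuncs: ndarray operators return NotImplemented
    # when the other operand is a Foo.
    __array_ufunc__ = None

    def __lt__(self, rhs): return 0
    def __le__(self, rhs): return 1
    def __eq__(self, rhs): return 2
    def __ne__(self, rhs): return 3
    def __gt__(self, rhs): return 4
    def __ge__(self, rhs): return 5

a = np.random.random((2, 2))
b = Foo()
print(a < b)   # 4: the array gives up, Python calls b.__gt__(a)
print(a >= b)  # 1: reflected to b.__le__(a)
```

Unlike `set_numeric_ops`, this is a per-class opt-in rather than a global change to every array. The trade-off is that calling a ufunc directly, e.g. `np.less(a, b)`, raises `TypeError` for such objects.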
answered Oct 23 '22 16:10 by AndyJost