
Equality Comparison with NumPy Instance Invokes `__bool__`

I have defined a class whose __ge__ method returns an instance of the class, and whose __bool__ method must never be invoked (similar to a Pandas Series).

Why is X.__bool__ invoked during np.int8(0) <= x, but not for any of the other examples? Who is invoking it? I have read the Data Model docs but I haven’t found my answer there.

import numpy as np
import pandas as pd

class X:
    def __bool__(self):
        print(f"{self}.__bool__")
        assert False
    def __ge__(self, other):
        print(f"{self}.__ge__")
        return X()

x = X()

np.int8(0) <= x

# Console output:
# <__main__.X object at 0x000001BAC70D5C70>.__ge__
# <__main__.X object at 0x000001BAC70D5D90>.__bool__
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "<stdin>", line 4, in __bool__
# AssertionError

0 <= x

# Console output:
# <__main__.X object at 0x000001BAC70D5C70>.__ge__
# <__main__.X object at 0x000001BAC70D5DF0>

x >= np.int8(0)

# Console output:
# <__main__.X object at 0x000001BAC70D5C70>.__ge__
# <__main__.X object at 0x000001BAC70D5D30>


pd_ge = pd.Series.__ge__
def ge_wrapper(self, other):
    print("pd.Series.__ge__")
    return pd_ge(self, other)

pd.Series.__ge__ = ge_wrapper

pd_bool = pd.Series.__bool__
def bool_wrapper(self):
    print("pd.Series.__bool__")
    return pd_bool(self)

pd.Series.__bool__ = bool_wrapper


np.int8(0) <= pd.Series([1,2,3])

# Console output:
# pd.Series.__ge__
# 0    True
# 1    True
# 2    True
# dtype: bool
Mike R asked Nov 06 '22 01:11



1 Answer

I suspect that np.int8.__le__ is defined so that, instead of returning NotImplemented and letting X.__ge__ take over, it tries to return something like not (np.int8(0) > x), and then np.int8.__gt__ returns NotImplemented. Once the reflected X.__lt__(x, np.int8(0)) returns an instance of X rather than a Boolean value, Python has to call __bool__ on that result in order to evaluate the not.
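To make that hypothesis concrete, here is a minimal pure-Python sketch. NegatingScalar is a hypothetical class invented for illustration and is not NumPy's actual code; the stand-in X also defines __lt__ (the question's X only defines __ge__), because __lt__ is the reflected method Python tries when __gt__ returns NotImplemented.

```python
class X:
    """Stand-in for the question's class: comparisons return X, truth-testing is forbidden."""

    def __bool__(self):
        raise TypeError("X.__bool__ must not be called")

    def __ge__(self, other):
        return X()

    def __lt__(self, other):
        # Added for this sketch only; the question's X does not define it.
        return X()


class NegatingScalar:
    """Hypothetical scalar (not NumPy) whose <= is written as `not (self > other)`."""

    def __gt__(self, other):
        # Defer to the reflected method, other.__lt__(self).
        return NotImplemented

    def __le__(self, other):
        # `not` has to truth-test whatever `>` produced.
        return not (self > other)


x = X()

# A plain int returns NotImplemented from __le__, so Python falls back to
# x.__ge__(0) and hands that result back untouched; __bool__ is never needed.
plain_result = 0 <= x

# The hypothetical scalar ends up evaluating `not X()`, which calls X.__bool__.
try:
    NegatingScalar() <= x
    bool_was_called = False
except TypeError:
    bool_was_called = True
```

If a left-hand operand behaved like NegatingScalar, the __bool__ call would come from the not inside its own __le__, not from the Python comparison machinery itself.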

(Still trying to track down where int8.__gt__ is defined to confirm.)

(Update: not quite. int8 uses a single generic rich comparison function that simply converts the value to a 0-dimensional array, then returns the result of PyObject_RichCompare on the array and x.)
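One way to picture how that path could still end in a __bool__ call is the toy model below. It assumes, without having confirmed it in the NumPy source, that the array comparison stores each element's result as a boolean. ToyZeroDimArray is a hypothetical stand-in, not ndarray, and this X counts truth-tests instead of asserting.

```python
import operator


class X:
    """Stand-in for the question's class; counts truth-tests instead of asserting."""

    bool_calls = 0

    def __bool__(self):
        X.bool_calls += 1
        return True

    def __ge__(self, other):
        return X()


class ToyZeroDimArray:
    """Hypothetical wrapper, NOT NumPy's ndarray: holds one value and compares it."""

    def __init__(self, value):
        self.value = value

    def __le__(self, other):
        # Element-level comparison; for an int value this falls back to
        # other.__ge__(value), which returns an X instance.
        element_result = operator.le(self.value, other)
        # Assumed step: the result is coerced into a boolean slot.
        return bool(element_result)


x = X()

direct = 0 <= x                     # reflected to x.__ge__(0), no truth-test
calls_after_direct = X.bool_calls

wrapped = ToyZeroDimArray(0) <= x   # same reflection, then coerced to bool
calls_after_wrapped = X.bool_calls
```

In this model the extra call comes from the coercion step rather than from the comparison protocol, which would match the traceback in the question showing __ge__ being called first and __bool__ second.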


I did find this function that appears to ultimately implement np.int8.__le__:

static NPY_INLINE int
rational_le(rational x, rational y) {
    return !rational_lt(y,x);
}

It's not clear to me how this function would be bypassed when one of the arguments (like x) is not a NumPy type. I think I give up.

chepner answered Nov 13 '22 18:11
