Design __eq__ that compares __dict__ of self and other safe from RecursionError

I've stumbled upon a really weird Python 3 issue, the cause of which I do not understand.

I'd like to compare my objects by checking if all their attributes are equal.

Some of the child classes will have fields that contain references to methods bound to self - and that causes a RecursionError.

Here's the PoC:

class A:

    def __init__(self, field):
        self.methods = [self.method]
        self.field = field

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        return self.__dict__ == other.__dict__

    def method(self):
        pass


first = A(field='foo')
second = A(field='bar')

print(first == second)

Running the code above in Python 3 raises RecursionError and I'm not sure why. It seems that A.__eq__ is used to compare the bound methods kept in self.methods. So my first question is - why? Why is the object's __eq__ called to compare a bound method of that object?

The second question is: what kind of filter on __dict__ should I use to protect __eq__ from this issue? In the PoC above, self.method is kept simply in a list, but sometimes it may be stored in another structure, so the filtering would have to cover every possible container that can hold the self-reference.

One clarification: I do need to keep the self.method function in a self.methods field. The use case here is similar to unittest.TestCase._cleanups - a stack of methods that are to be called after the test is finished. The framework must be able to run the following code:


# obj is a child instance of the A class

obj.methods.append(obj.child_method)

for method in obj.methods:
    method()

Another clarification: the only code I can change is the __eq__ implementation.

Mirek asked Dec 11 '22 03:12


1 Answer

"Why is the object's __eq__ called to compare a bound method of that object?":

Because bound methods compare by the following algorithm:

  1. Is the self bound to each method equal?
  2. If so, is the function implementing the method the same?

Step 1 causes your infinite recursion: comparing the __dict__s eventually ends up comparing the bound methods, and to do that it has to compare the objects to each other again. Now you're right back where you started, and it continues forever.
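To make the two steps concrete, here is a small sketch. Widget, Node and model_method_eq are hypothetical names used only for illustration, and model_method_eq is a pure-Python model of the comparison as described above, not the interpreter's actual implementation (which may differ between Python versions).

```python
class Widget:
    def ping(self):
        pass


w = Widget()
m1, m2 = w.ping, w.ping

# A bound method pairs an instance (__self__) with a plain function (__func__).
assert m1.__self__ is w
assert m1.__func__ is Widget.ping
assert m1 is not m2   # each attribute access builds a new bound-method object
assert m1 == m2       # same instance, same function


def model_method_eq(a, b):
    """Hypothetical model of the two steps listed above."""
    return a.__self__ == b.__self__ and a.__func__ is b.__func__


class Node:
    def __init__(self):
        self.callbacks = [self.run]

    def __eq__(self, other):
        if type(self) is not type(other):
            return NotImplemented
        if len(self.callbacks) != len(other.callbacks):
            return False
        for mine, theirs in zip(self.callbacks, other.callbacks):
            # Step 1 of the model evaluates `self == other` again,
            # which re-enters this very method.
            if not model_method_eq(mine, theirs):
                return False
        return True

    def run(self):
        pass


def compare_nodes():
    try:
        return Node() == Node()
    except RecursionError:
        return "RecursionError"


result = compare_nodes()
print(result)
```

With this model, Node.__eq__ and model_method_eq keep calling each other until the recursion limit is hit, which is the same cycle as in the question's PoC.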

The only "solutions" I can come up with off-hand are:

  1. Something like the reprlib.recursive_repr decorator (which would be extremely hacky, since you'd be heuristically determining if you're comparing for bound method related reasons based on whether __eq__ was re-entered), or
  2. A wrapper for any bound methods you store that replaces equality testing of the respective selfs with identity testing.
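For option 1, a rough sketch of what such a guard could look like follows, changing only __eq__. This is a heuristic under stated assumptions, not a recommended implementation: _state and its pairs set are hypothetical names, and the guard simply assumes that any re-entry for the same pair of objects came from a bound-method comparison.

```python
import threading

# Hypothetical per-thread registry of (id(left), id(right)) pairs
# whose comparison is currently in progress.
_state = threading.local()


class A:
    def __init__(self, field):
        self.methods = [self.method]
        self.field = field

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        active = getattr(_state, "pairs", None)
        if active is None:
            active = _state.pairs = set()
        key = (id(self), id(other))
        if key in active:
            # Re-entered for the same pair: assume we arrived here through a
            # bound method's __self__ and fall back to identity (the hacky part).
            return self is other
        active.add(key)
        try:
            return self.__dict__ == other.__dict__
        finally:
            active.discard(key)

    def method(self):
        pass


first = A(field='foo')
second = A(field='bar')

differs = first == second   # completes without RecursionError
same = first == first
print(differs, same)
```

Note the consequence of falling back to identity: two distinct instances that each store methods bound to themselves never compare equal, which is the same semantics the wrapper in option 2 gives you.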

The wrapper for bound methods isn't terrible at least. You'd basically just make a simple wrapper of the form:

import types

class IdentityComparableMethod:
    __slots__ = '_method',
    def __new__(cls, method):
        # Using __new__ prevents reinitialization, part of immutability contract
        # that justifies defining __hash__
        self = super().__new__(cls)
        self._method = method
        return self

    def __getattr__(self, name):
        '''Attribute access should match bound method's'''
        return getattr(self._method, name)

    def __eq__(self, other):
        '''Comparable to other instances, and normal methods'''
        if not isinstance(other, (IdentityComparableMethod, types.MethodType)):
            return NotImplemented
        return (self.__self__ is other.__self__ and
                self.__func__ is other.__func__)

    def __hash__(self):
        '''Hash identically to the method'''
        return hash(self._method)

    def __call__(self, *args, **kwargs):
        '''Delegate to method'''
        return self._method(*args, **kwargs)

    def __repr__(self):
        return '{0.__class__.__name__}({0._method!r})'.format(self)

Then, when storing bound methods, wrap them in that class, e.g.:

self.methods = [IdentityComparableMethod(self.method)]

You may want to make methods itself enforce this via additional magic (so it only stores functions or IdentityComparableMethods), but that's the basic idea.

Other answers address more targeted filtering; this is just a way to make that filtering unnecessary.

Performance note: I didn't heavily optimize for performance; __getattr__ is the simplest way of reflecting all the attributes of the underlying method. If you want comparisons to go faster, you can fetch out __self__ during initialization and cache it on self directly to avoid __getattr__ calls, changing the __slots__ and __new__ declaration to:

    __slots__ = '_method', '__self__'
    def __new__(cls, method):
        # Using __new__ prevents reinitialization, part of immutability contract
        # that justifies defining __hash__
        self = super().__new__(cls)
        self._method = method
        self.__self__ = method.__self__
        return self

That makes a pretty significant difference in comparison speed; in local %timeit tests, the first == second comparison dropped from 2.77 μs to 1.05 μs. You could cache __func__ as well if you like, but since it's the fallback comparison, it's less likely to be checked at all (and you'd slow construction a titch for an optimization you're less likely to use).

Alternatively, instead of caching, you can just manually define @propertys for __self__ and __func__, which are slower than raw attributes (comparison ran in 1.41 μs), but incur no construction time cost at all (so if no comparison is ever run, you don't pay the lookup cost).
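As a sketch of the @property variant, the wrapper could look roughly like this. It is the same class as above with two properties added; the class A used to exercise it is the PoC from the question with the stored method wrapped, and the timing figures are not reproduced here.

```python
import types


class IdentityComparableMethod:
    __slots__ = '_method',

    def __new__(cls, method):
        # Using __new__ prevents reinitialization
        self = super().__new__(cls)
        self._method = method
        return self

    @property
    def __self__(self):
        # Found by normal attribute lookup, so __getattr__ is skipped
        return self._method.__self__

    @property
    def __func__(self):
        return self._method.__func__

    def __getattr__(self, name):
        # Everything else still mirrors the underlying bound method
        return getattr(self._method, name)

    def __eq__(self, other):
        if not isinstance(other, (IdentityComparableMethod, types.MethodType)):
            return NotImplemented
        return (self.__self__ is other.__self__ and
                self.__func__ is other.__func__)

    def __hash__(self):
        return hash(self._method)

    def __call__(self, *args, **kwargs):
        return self._method(*args, **kwargs)


class A:
    def __init__(self, field):
        self.methods = [IdentityComparableMethod(self.method)]
        self.field = field

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        return self.__dict__ == other.__dict__

    def method(self):
        pass


first = A(field='foo')
second = A(field='bar')

unequal = first == second                  # no recursion into A.__eq__
matches_plain = first.methods[0] == first.method
for method in first.methods:
    method()                               # wrapped methods stay callable
```

Nothing is stored beyond _method, so construction costs the same as in the first version; the lookup through the underlying method only happens when a comparison actually runs.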

ShadowRanger answered Dec 22 '22 00:12