
Equality of Python classes using slots

Another question provides a nice, simple solution for implementing a test for equality of objects. I'll repeat the answer for context:

class CommonEqualityMixin(object):

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
            and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

class Foo(CommonEqualityMixin):

    def __init__(self, item):
        self.item = item

I would like to do this for a class that uses __slots__. I understand that both the base class and the subclass will have to use slots, but how would you define __eq__ for this to work with slots?
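For context, here is a minimal sketch of why the dict-based mixin above does not carry over to slotted classes. The class names DictMixin, SlottedDictMixin, HalfSlotted and FullySlotted are hypothetical. If the mixin itself lacks __slots__, instances still get a __dict__, but slotted attributes are never stored in it, so every instance compares equal. If the mixin declares __slots__ = (), there is no __dict__ to read at all.

```python
def dict_based_eq(self, other):
    # The comparison from the answer quoted above: same class and equal __dict__.
    return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

class DictMixin(object):            # no __slots__: instances still get a __dict__
    __eq__ = dict_based_eq

class SlottedDictMixin(object):     # __slots__ = (): no __dict__ anywhere in the hierarchy
    __slots__ = ()
    __eq__ = dict_based_eq

class HalfSlotted(DictMixin):
    __slots__ = ('item',)
    def __init__(self, item):
        self.item = item            # stored in the slot, not in __dict__

class FullySlotted(SlottedDictMixin):
    __slots__ = ('item',)
    def __init__(self, item):
        self.item = item

print(HalfSlotted(1) == HalfSlotted(2))   # True: both __dict__s are empty
try:
    FullySlotted(1) == FullySlotted(2)
except AttributeError as exc:
    print('AttributeError:', exc)          # there is no __dict__ to compare
```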

gaefan asked Dec 23 '10 21:12


People also ask

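What is __slots__ in a Python class?

__slots__ is a class variable listing the names of the attributes that instances are allowed to have. When the class is created, Python turns each listed name into a descriptor on the class. Reading __slots__, even through an instance, therefore returns that class-level sequence of names, not the instance's values; the values themselves are read with normal attribute access such as getattr. Reassigning __slots__ after the class has been created does not change which attributes instances can hold.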
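A short sketch of the points above, using a hypothetical Point class: __slots__ holds names rather than values, each name becomes a member descriptor on the class, and reassigning __slots__ later does not add storage.

```python
class Point(object):
    __slots__ = ('x', 'y')  # declared once, when the class is created

    def __init__(self, x, y):
        self.x = x
        self.y = y

p = Point(1, 2)
print(p.__slots__)                                  # ('x', 'y'): the class attribute
print([getattr(p, name) for name in p.__slots__])   # [1, 2]: this instance's values
print(type(Point.__dict__['x']).__name__)           # member_descriptor

# Reassigning __slots__ afterwards does not create a slot for 'z'.
Point.__slots__ = ('x', 'y', 'z')
try:
    p.z = 3
except AttributeError as exc:
    print('AttributeError:', exc)
```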

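How do slots work in Python?

Slots are a mechanism for reducing the memory used by objects. By default, each instance stores its attributes in a dynamic per-instance dictionary (__dict__), which is what lets you add new attributes at any time. A class that declares __slots__ instead gets a fixed set of attribute fields decided when the class is created, so no per-instance dictionary is needed (as long as every class in the hierarchy declares __slots__), and assigning an attribute whose name is not listed raises AttributeError.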
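A minimal comparison, assuming two hypothetical classes WithDict and WithSlots that each store one attribute. Exact byte counts depend on the Python version and platform, so the sizes are printed without expected values.

```python
import sys

class WithDict(object):
    def __init__(self):
        self.a = 1

class WithSlots(object):
    __slots__ = ('a',)
    def __init__(self):
        self.a = 1

d, s = WithDict(), WithSlots()
print(hasattr(d, '__dict__'))  # True: attributes live in a per-instance dict
print(hasattr(s, '__dict__'))  # False: 'a' is stored in a fixed slot

# The dict-based instance pays for the object plus its dictionary.
print('dict-based:', sys.getsizeof(d) + sys.getsizeof(d.__dict__))
print('slotted:   ', sys.getsizeof(s))

try:
    s.b = 2  # 'b' is not listed in __slots__
except AttributeError as exc:
    print('AttributeError:', exc)
```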

When does Python use __eq__?

Python automatically calls the __eq__ method of a class when you use the == operator to compare instances of that class. If you don't provide your own implementation, the default inherited from object compares by identity, so == gives the same result as the is operator.
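A quick illustration with a hypothetical Plain class that defines no __eq__ of its own. In Python 3, object.__eq__ returns NotImplemented for two different objects, and == then falls back to an identity check.

```python
class Plain(object):
    # No __eq__ defined: the default inherited from object is used.
    def __init__(self, value):
        self.value = value

a, b = Plain(1), Plain(1)
print(a == b)       # False: two distinct objects, even with equal values
print(a == a)       # True: same object
print(a.__eq__(b))  # NotImplemented: == then falls back to identity
```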

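How do you compare two objects in Python?

Summary:
- To compare two objects, implement the rich comparison methods: __eq__, __ne__, __lt__, __le__, __gt__ and __ge__.
- If the left operand's method is missing or returns NotImplemented, Python tries the reflected method on the right operand: a < b falls back to b.__gt__(a), a <= b to b.__ge__(a), and a == b to b.__eq__(a). So defining __lt__ also makes > work, and defining __le__ makes >= work, but __lt__ alone does not give you <=. The functools.total_ordering decorator fills in the missing ordering methods from __eq__ plus one of them.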
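A small sketch of both behaviours, using hypothetical Version and OrderedVersion classes: reflection turns a missing > into the other operand's <, while <= still fails until functools.total_ordering supplies it.

```python
import functools

class Version(object):
    def __init__(self, n):
        self.n = n
    def __lt__(self, other):
        return self.n < other.n

print(Version(2) > Version(1))  # True: evaluated as Version(1).__lt__(Version(2))
try:
    Version(1) <= Version(2)    # neither __le__ nor a reflected __ge__ exists
except TypeError as exc:
    print('TypeError:', exc)

@functools.total_ordering
class OrderedVersion(object):
    def __init__(self, n):
        self.n = n
    def __eq__(self, other):
        return self.n == other.n
    def __lt__(self, other):
        return self.n < other.n

print(OrderedVersion(1) <= OrderedVersion(2))  # True, derived by total_ordering
print(OrderedVersion(3) >= OrderedVersion(2))  # True
```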


1 Answer

import operator

class CommonEqualityMixin(object):

    __slots__ = ()

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            if self.__slots__ == other.__slots__:
                attr_getters = [operator.attrgetter(attr) for attr in self.__slots__]
                return all(getter(self) == getter(other) for getter in attr_getters)

        return False

    def __ne__(self, other):
        return not self.__eq__(other)

An example of usage:

class Foo(CommonEqualityMixin):
    __slots__ = ('a', )
    def __init__(self, a):
        self.a = a

Foo(1) == Foo(2)
# False
Foo(1) == Foo(1)
# True

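N.B.: be aware that __slots__ does not accumulate across inheritance the way the contents of __dict__ do. Each subclass's __slots__ lists only the names that subclass adds, and self.__slots__ returns the value from the most-derived class that defines it. So if, for example, a new class FooBar inherits from Foo, the code above compares only FooBar's own slots and silently ignores a.

Example: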

class FooBar(Foo):
    __slots__ = ('z',)
    def __init__(self, a, z):
        self.z = z
        super(FooBar, self).__init__(a)

FooBar(1, 1) == FooBar(2, 1)
# True, even though a differs, because only 'z' is compared

print(FooBar(1, 1).__slots__)
# ('z',)
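One way around this limitation is to collect slot names from every class in the MRO instead of reading self.__slots__. This is a sketch, not the answer's code: the names SlotsEqualityMixin, _slot_names and _MISSING are hypothetical, it targets Python 3, it compares only instances of exactly the same type (returning NotImplemented otherwise, unlike the isinstance check above), and it assumes no private double-underscore slot names, which Python name-mangles.

```python
_MISSING = object()  # sentinel for slots that were never assigned


class SlotsEqualityMixin(object):
    __slots__ = ()

    @classmethod
    def _slot_names(cls):
        """Collect slot names declared anywhere in the MRO, not just on cls."""
        names = []
        for klass in cls.__mro__:
            # Read each class's own __slots__, not the inherited lookup.
            slots = klass.__dict__.get('__slots__', ())
            if isinstance(slots, str):  # __slots__ = 'z' declares a single slot
                slots = (slots,)
            for name in slots:
                if name not in ('__dict__', '__weakref__') and name not in names:
                    names.append(name)
        return names

    def __eq__(self, other):
        # Only compare instances of exactly the same class.
        if type(other) is not type(self):
            return NotImplemented
        return all(getattr(self, name, _MISSING) == getattr(other, name, _MISSING)
                   for name in self._slot_names())
    # Python 3 derives != from __eq__, so no __ne__ is needed here.


class Foo(SlotsEqualityMixin):
    __slots__ = ('a',)
    def __init__(self, a):
        self.a = a


class FooBar(Foo):
    __slots__ = ('z',)
    def __init__(self, a, z):
        self.z = z
        super(FooBar, self).__init__(a)


print(FooBar(1, 1) == FooBar(2, 1))  # False: the inherited slot 'a' is now compared
print(FooBar(1, 1) == FooBar(1, 1))  # True
print(FooBar._slot_names())          # ['z', 'a']
```

Because the mixin defines __eq__ without __hash__, Python 3 makes these instances unhashable; add a __hash__ over the same slot values if you need them in sets or as dictionary keys.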
mouad answered Sep 24 '22 00:09