 

Using __contains__ on a list of custom class objects

Tags:

python

I have a simple class defined like this:

class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __contains__(self, item):
        return item == self.id

Using this class, I can do simple checks of single instances of the class like so:

>>> user1 = User(1, 'name_1')
>>> 1 in user1
True
>>> 2 in user1
False

This is working as expected.

How can I check if a value is in a list of User objects though? It always seems to return False.

Example:

from random import randint
from pprint import pprint
users = [User(x, 'name_{}'.format(x)) for x in xrange(5)]
pprint(users, indent=4)

for x in xrange(5):
    i = randint(2,6) 
    if i in users:
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.append(User(i, 'new_user{}'.format(i)))
pprint(users, indent=4)

This creates output similar to this:

[   0 => name_0, 
    1 => name_1, 
    2 => name_2, 
    3 => name_3, 
    4 => name_4]
6 is not in users list. Creating new user with id: 6
6 is not in users list. Creating new user with id: 6
6 is not in users list. Creating new user with id: 6
3 is not in users list. Creating new user with id: 3
3 is not in users list. Creating new user with id: 3
[   0 => name_0,
    1 => name_1,
    2 => name_2,
    3 => name_3,
    4 => name_4,
    6 => new_user6,
    6 => new_user6,
    6 => new_user6,
    3 => new_user3,
    3 => new_user3]

The issue is that the user with id 6 should have been created only once, because it didn't exist yet. The second and third attempts to add 6 should have failed. The user with id 3 shouldn't have been recreated at all, because it was part of the initialization of the users variable.

How do I modify my __contains__ method so that in works properly when checking a value against a list of instances of my class?

asked Dec 08 '22 23:12 by Andy


2 Answers

If users is a list of users and you check if i in users, then you are not checking User.__contains__; you are checking list.__contains__, which compares i against each element with == (after an identity check). Whatever you do in User.__contains__ is not going to affect the result of checking if i is in a list.
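
A quick way to see this, as a minimal sketch run under Python 3 with the question's User class reproduced unchanged: in on a single User calls User.__contains__, but in on a list never does. Because User defines no __eq__, comparing a User to an int falls back to identity and is always False.

```python
class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __contains__(self, item):
        return item == self.id


users = [User(x, 'name_{}'.format(x)) for x in range(5)]
u1 = users[1]
target = 1

print(1 in u1)        # User.__contains__ is called: True
print(1 in users)     # list.__contains__ is called: False
print(u1 in users)    # the same object is found by identity: True

# Roughly the test list.__contains__ applies to each element. User defines no
# __eq__, so `elem == target` falls back to identity and is always False here.
print(any(elem is target or elem == target for elem in users))  # False
```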

If you want to check if i matches any User in users, you could do:

if any(i in u for u in users):

Or a bit more clearly:

if any(u.id == i for u in users):

and avoid using User.__contains__ at all.
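
Applied to the question's loop, this might look like the following minimal Python 3 sketch. It assumes a fixed, hypothetical candidate_ids list in place of randint so the result is reproducible, drops User.__contains__ as suggested, and adds a __repr__ only to mimic the "id => name" display in the question (the question's own __repr__ isn't shown).

```python
class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __repr__(self):
        # Mimics the "id => name" format shown in the question's output
        return '{} => {}'.format(self.id, self.name)


users = [User(x, 'name_{}'.format(x)) for x in range(5)]

# Hypothetical fixed ids standing in for randint(2, 6)
candidate_ids = [6, 6, 6, 3, 3]

for i in candidate_ids:
    # Compare against each user's id explicitly instead of relying on `in`
    if any(u.id == i for u in users):
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.append(User(i, 'new_user{}'.format(i)))

print(users)
```

Only the first 6 creates a user; the later 6s and both 3s report that the user already exists. Each check scans the whole list, which is fine for a handful of users.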

answered Dec 10 '22 11:12 by khelwood


This seems like a case where you really want to define __eq__ to accept comparisons to both other User objects and ints. That would make membership (in) checks on collections of Users work automatically, and would make much more sense in general usage than implementing __contains__ on a non-container type.

import sys
from operator import index

class User(object):  # Explicit inheritance from object can be removed for Py3-only code
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __eq__(self, item):
        if isinstance(item, User):
            return self.id == item.id and self.name == item.name
        try:
            # Accept any int-like thing
            return self.id == index(item)
        except TypeError:
            return NotImplemented

    # Canonical mirror of __eq__ only needed on Py2; Py3 defines it implicitly
    if sys.version_info < (3,):
        def __ne__(self, other):
            equal = self.__eq__(other)
            return equal if equal is NotImplemented else not equal

    def __hash__(self):
        return self.id
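
A quick check of how this class behaves, as a minimal sketch run under Python 3 with the class reproduced unchanged. It shows that == against an int works in both directions and that both list and set membership find a user by id.

```python
import sys
from operator import index


class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __eq__(self, item):
        if isinstance(item, User):
            return self.id == item.id and self.name == item.name
        try:
            # Accept any int-like thing
            return self.id == index(item)
        except TypeError:
            return NotImplemented

    if sys.version_info < (3,):
        def __ne__(self, other):
            equal = self.__eq__(other)
            return equal if equal is NotImplemented else not equal

    def __hash__(self):
        return self.id


u3 = User(3, 'name_3')
print(u3 == 3)               # True: compared through operator.index
print(3 == u3)               # True: int returns NotImplemented, so Python tries u3.__eq__(3)
print(u3 == User(3, 'x'))    # False: User-to-User comparison also checks the name
print(u3 == '3')             # False: index('3') raises TypeError, so __eq__ returns NotImplemented

users_list = [User(x, 'name_{}'.format(x)) for x in range(5)]
users_set = set(users_list)
print(3 in users_list, 3 in users_set)   # True True
print(7 in users_list, 7 in users_set)   # False False
```

One caveat: because __hash__ returns self.id directly, a User left with the default id=None cannot be added to a set or used as a dict key; Python raises TypeError because __hash__ must return an integer.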

Now you can use your type with normal collections (including as set members and dict keys), and a user can be looked up by its integer id.

from operator import attrgetter
from pprint import pprint
from random import randint

# Use a set for faster lookup; sort for display when needed
users = {User(x, 'name_{}'.format(x)) for x in range(5)}
pprint(sorted(users, key=attrgetter('id')), indent=4)

for x in range(5):
    i = randint(2, 6)
    if i in users:
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.add(User(i, 'new_user{}'.format(i)))
pprint(sorted(users, key=attrgetter('id')), indent=4)
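
For Python 3-only code, the class can drop the object base and the sys.version_info branch, as the comments above suggest. The following is a minimal self-contained sketch under a few assumptions of my own: a hypothetical fixed candidate_ids list replaces randint so the run is reproducible, __repr__ is added only for display, and __hash__ returns hash(self.id) rather than self.id so that a user with id=None can still be hashed (for ordinary int ids the hash value is the same).

```python
from operator import attrgetter, index
from pprint import pprint


class User:
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __eq__(self, item):
        if isinstance(item, User):
            return self.id == item.id and self.name == item.name
        try:
            # Accept anything that implements __index__ (int, bool, ...)
            return self.id == index(item)
        except TypeError:
            return NotImplemented

    # Python 3 derives __ne__ from __eq__, but defining __eq__ alone would set
    # __hash__ to None, so __hash__ is still defined explicitly.
    def __hash__(self):
        return hash(self.id)

    def __repr__(self):
        return '{} => {}'.format(self.id, self.name)


users = {User(x, 'name_{}'.format(x)) for x in range(5)}

candidate_ids = [6, 6, 6, 3, 3]  # hypothetical stand-in for randint(2, 6)
for i in candidate_ids:
    # Set lookup: hash(i) matches hash(User(i, ...)), then __eq__ confirms
    if i in users:
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.add(User(i, 'new_user{}'.format(i)))

pprint(sorted(users, key=attrgetter('id')), indent=4)
```

Only one new user (id 6) is added; every later 6 and both 3s are found in the set.
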
answered Dec 10 '22 12:12 by ShadowRanger