
Suppress key addition in collections.defaultdict

When a missing key is queried in a defaultdict object, the key is automatically added to the dictionary:

from collections import defaultdict

d = defaultdict(int)
res = d[5]

print(d)
# defaultdict(<class 'int'>, {5: 0})
# we want this dictionary to remain empty

However, often we want to only add keys when they are assigned explicitly or implicitly:

d[8] = 1  # we want this key added
d[3] += 1 # we want this key added

One use case is simple counting, to avoid the higher overhead of collections.Counter, but this feature may also be desirable generally.
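
To illustrate, here is a minimal sketch (with made-up data) of how read-only lookups pollute a defaultdict(int) used for counting:

from collections import defaultdict

counts = defaultdict(int)
for word in ['a', 'b', 'a']:
    counts[word] += 1      # these keys should be stored

if counts['z'] > 0:        # a read-only check silently adds 'z'
    pass

print(counts)
# defaultdict(<class 'int'>, {'a': 2, 'b': 1, 'z': 0})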


Counter example [pardon the pun]

This is the functionality I want:

from collections import Counter
c = Counter()
res = c[5]  # 0
print(c)  # Counter()

c[8] = 1  # key added successfully
c[3] += 1 # key added successfully

But Counter is significantly slower than defaultdict(int); in my tests it is usually around 2x slower.

In addition, Counter only covers the defaultdict(int) case, while defaultdict can take list, set, etc. as its factory.
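
For context, a quick sketch (with made-up data) of the kind of non-int factory usage that Counter cannot replace, e.g. grouping values into lists:

from collections import defaultdict

pairs = [('fruit', 'apple'), ('veg', 'carrot'), ('fruit', 'pear')]
groups = defaultdict(list)
for kind, name in pairs:
    groups[kind].append(name)

print(groups)
# defaultdict(<class 'list'>, {'fruit': ['apple', 'pear'], 'veg': ['carrot']})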


Is there a way to implement the above behaviour efficiently; for instance, by subclassing defaultdict?


Benchmarking example

%timeit DwD(lst)           # 72 ms
%timeit dd(lst)            # 44 ms
%timeit counter_func(lst)  # 98 ms
%timeit af(lst)            # 72 ms

Test code:

import numpy as np
from collections import defaultdict, Counter, UserDict

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        _sentinel = object()
        value = self.get(key, _sentinel)

        if value is _sentinel:
            return self.default_factory()
        return value

class DictWithDefaults(dict):
    __slots__ = ['_factory']  # avoid using extra memory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

lst = np.random.randint(0, 10, 100000)

def DwD(lst):
    d = DictWithDefaults(int)
    for i in lst:
        d[i] += 1
    return d

def dd(lst):
    d = defaultdict(int)
    for i in lst:
        d[i] += 1
    return d

def counter_func(lst):
    d = Counter()
    for i in lst:
        d[i] += 1
    return d

def af(lst):
    d = DefaultDict(int)
    for i in lst:
        d[i] += 1
    return d

Note Regarding Bounty Comment:

@Aran-Fey's solution has been updated since the bounty was offered, so please disregard the bounty comment.

asked Apr 11 '18 by jpp


2 Answers

Rather than messing about with collections.defaultdict to make it do what we want, it seems easier to implement our own:

class DefaultDict(dict):
    def __init__(self, default_factory, **kwargs):
        super().__init__(**kwargs)

        self.default_factory = default_factory

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            return self.default_factory()

This works the way you want:

d = DefaultDict(int)

res = d[5]
d[8] = 1 
d[3] += 1

print(d)  # {8: 1, 3: 1}

However, it can behave unexpectedly for mutable types:

d = DefaultDict(list)
d[5].append('foobar')

print(d)  # output: {}

This is probably the reason why defaultdict remembers the value when a nonexistent key is accessed.
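
For comparison, the standard defaultdict stores the freshly created list, so the append is not lost:

from collections import defaultdict

d = defaultdict(list)
d[5].append('foobar')

print(d)  # defaultdict(<class 'list'>, {5: ['foobar']})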


Another option is to extend defaultdict and add a new method that looks up a value without remembering it:

from collections import defaultdict

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        return self.get(key, self.default_factory())

Note that the get_and_forget method calls the default_factory() every time, regardless of whether the key already exists in the dict or not. If this is undesirable, you can implement it with a sentinel value instead:

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        _sentinel = object()
        value = self.get(key, _sentinel)

        if value is _sentinel:
            return self.default_factory()
        return value

This has better support for mutable types, because it allows you to choose whether the value should be added to the dict or not.
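
For example, a quick sketch of that choice with a list factory (variable names are mine):

d = DefaultDict(list)

tmp = d.get_and_forget(5)    # peek without storing anything
print(d)                     # defaultdict(<class 'list'>, {})

d[5] = tmp + ['foobar']      # store explicitly once we decide to
print(d)                     # defaultdict(<class 'list'>, {5: ['foobar']})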

answered by Aran-Fey


If you just want a dict that returns a default value when you access a non-existing key then you could simply subclass dict and implement __missing__:

object.__missing__(self, key)

Called by dict.__getitem__() to implement self[key] for dict subclasses when key is not in the dictionary.

That would look like this:

class DictWithDefaults(dict):
    # not necessary, just a memory optimization
    __slots__ = ['_factory']  

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

In this case I used a defaultdict-like approach so you have to pass in a factory that should provide the default value when called:

>>> dwd = DictWithDefaults(int)
>>> dwd[0]  # key does not exist
0 
>>> dwd     # key still doesn't exist
{}
>>> dwd[0] = 10
>>> dwd
{0: 10}

When you do assignments (explicitly or implicitly) the value will be added to the dictionary:

>>> dwd = DictWithDefaults(int)
>>> dwd[0] += 1
>>> dwd
{0: 1}

>>> dwd = DictWithDefaults(list)
>>> dwd[0] += [1]
>>> dwd
{0: [1]}

You wondered how collections.Counter does it; as of CPython 3.6.5 it also uses __missing__:

class Counter(dict):
    ...

    def __missing__(self, key):
        'The count of elements not in the Counter is zero.'
        # Needed so that self[missing_item] does not raise KeyError
        return 0

    ...

Better performance?!

You mentioned that speed is a concern, so you could turn that class into a C extension class (assuming you use CPython), for example with Cython (here I'm using the Jupyter magic commands to build the extension class):

%load_ext cython

%%cython

cdef class DictWithDefaultsCython(dict):
    cdef object _factory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

Benchmark

Based on your benchmark:

from collections import Counter, defaultdict

def d_py(lst):
    d = DictWithDefaults(int)
    for i in lst:
        d[i] += 1
    return d

def d_cy(lst):
    d = DictWithDefaultsCython(int)
    for i in lst:
        d[i] += 1
    return d

def d_dd(lst):
    d = defaultdict(int)
    for i in lst:
        d[i] += 1
    return d

def d_c(lst):
    # Counter filled via a manual loop (the d_c referenced in the
    # function list below); same as counter_func in the question
    d = Counter()
    for i in lst:
        d[i] += 1
    return d

Given that this is just counting, it would be an (unforgivable) oversight not to include a benchmark that simply uses the Counter initializer.
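
For reference, the initializer route counts a whole iterable in one call, which is why Counter itself appears in the function list below (toy data for illustration):

from collections import Counter

lst = [0, 1, 0, 2, 1, 0]
print(Counter(lst))  # Counter({0: 3, 1: 2, 2: 1})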

I have recently written a small benchmarking tool that I think might come in handy here (but you could do it using %timeit as well):

from simple_benchmark import benchmark
import random

sizes = [2**i for i in range(2, 20)]
unique_lists = {i: list(range(i)) for i in sizes}
identical_lists = {i: [0]*i for i in sizes}
mixed = {i: [random.randint(0, i // 2) for j in range(i)]  for i in sizes}

functions = [d_py, d_cy, d_dd, d_c, Counter]

b_unique = benchmark(functions, unique_lists, 'list size')
b_identical = benchmark(functions, identical_lists, 'list size')
b_mixed = benchmark(functions, mixed, 'list size')

With this result:

import matplotlib.pyplot as plt

f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True)
ax1.set_title('unique elements')
ax2.set_title('identical elements')
ax3.set_title('mixed elements')
b_unique.plot(ax=ax1)
b_identical.plot(ax=ax2)
b_mixed.plot(ax=ax3)

Note that it uses log-log scale for better visibility of differences:

[Figure: log-log timing plots for the unique, identical, and mixed element cases]

For long iterables, Counter(iterable) was by far the fastest. DictWithDefaultsCython and defaultdict were roughly equal (with DictWithDefaultsCython being slightly faster most of the time, even if that's not visible here), followed by DictWithDefaults and then Counter with the manual for-loop. Funny how Counter is both the fastest and the slowest.

Implicitly add the returned value if it is modified

Something I glossed over is that this differs considerably from defaultdict, because of the desired "just return the default, don't save it" behaviour with mutable types:

>>> from collections import defaultdict
>>> dd = defaultdict(list)
>>> dd[0].append(10)
>>> dd
defaultdict(list, {0: [10]})

>>> dwd = DictWithDefaults(list)
>>> dwd[0].append(10)
>>> dwd
{}

That means you actually need to set the element when you want the modified value to be visible in the dictionary.
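
In other words, the explicit pattern would look something like this:

dwd = DictWithDefaults(list)

value = dwd[0]       # returns a fresh list, nothing is stored yet
value.append(10)
dwd[0] = value       # only now does the key appear in the dict

print(dwd)  # {0: [10]}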

However, this somewhat intrigued me, so I want to share a way you could make that work (if desired). It's just a quick test and only works for append calls, using a proxy. Please don't use it in production code (from my point of view it has entertainment value only):

from wrapt import ObjectProxy

class DictWithDefaultsFunky(dict):
    __slots__ = ['_factory']  # avoid using extra memory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        ret = self._factory()
        dict_ = self

        class AppendTrigger(ObjectProxy):
            def append(self, val):
                self.__wrapped__.append(val)
                dict_[key] = ret

        return AppendTrigger(ret)

That's a dictionary that returns a proxy object (instead of the real default) and it overloads a method that, if called, adds the return value to the dictionary. And it "works":

>>> d = DictWithDefaultsFunky(list)
>>> a = d[10]
>>> d
{}

>>> a.append(1)
>>> d
{10: [1]}

But it does have a few pitfalls (that could be solved but it's just a proof-of-concept so I won't attempt it here):

>>> d = DictWithDefaultsFunky(list)
>>> a = d[10]
>>> b = d[10]
>>> d
{}
>>> a.append(1)
>>> d
{10: [1]}
>>> b.append(10)
>>> d  # oops, that overwrote the previously stored value ...
{10: [10]}

If you really want something like that you probably need to implement a class that really tracks changes within the values (and not just append calls).

If you want to avoid implicit assignments

In case you don't like the fact that += or similar operations add the value to the dictionary (as opposed to the previous example, which even tried to add the value in a very implicit fashion), then you should probably implement it as a regular method instead of a special method.

For example:

class SpecialDict(dict):
    __slots__ = ['_factory']

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def get_or_default_from_factory(self, key):
        try:
            return self[key]
        except KeyError:
            return self._factory()
        
>>> sd = SpecialDict(int)
>>> sd.get_or_default_from_factory(0)  
0
>>> sd  
{}
>>> sd[0] = sd.get_or_default_from_factory(0) + 1
>>> sd  
{0: 1}

This is similar to the behavior of Aran-Fey's answer, but instead of get with a sentinel it uses a try/except approach.

answered by MSeifert