
How can I filter a dictionary with arbitrary length tuples as keys efficiently?

TL;DR

What is the most efficient way to implement a filter function for a dictionary with keys of variable dimensions? The filter takes a tuple with the same dimensions as the dictionary's keys and returns all keys in the dictionary that match it, i.e. all keys for which filter[i] is None or filter[i] == key[i] holds for every dimension i.


In my current project, I need to handle dictionaries with a lot of data. The general structure of the dictionary is such that it contains tuples with 2 to 4 integers as keys and integers as values. All keys in a dictionary have the same dimensions. To illustrate, the following are examples of dictionaries I need to handle:

{(1, 2): 1, (1, 5): 2}
{(1, 5, 3): 2}
{(5, 2, 5, 2): 8}

These dictionaries contain a lot of entries, with the largest ones having about 20,000 entries. I frequently need to filter these entries, but often only look at certain indices of the key tuples. Ideally, I want a function to which I can supply a filter tuple. The function should then return all keys which match the filter tuple. If the filter tuple contains a None entry, it matches any value at that index of a dictionary key tuple.

Example of what the function should do for a dictionary with 2-dimensional keys:

>>> entries = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
>>> my_filter_fn(entries, (1, None))
{(1, 2), (1, 5)}
>>> my_filter_fn(entries, (None, 5))
{(1, 5), (2, 5)}
>>> my_filter_fn(entries, (2, 4))
set()
>>> my_filter_fn(entries, (None, None))
{(1, 2), (1, 5), (2, 5), (3, 9)}

As my dictionaries have tuples of different dimensions, I have tried solving this problem by writing a generator expression which takes the dimensions of the tuple into account:

def my_filter_fn(entries: dict, match: tuple):
    return (x for x in entries.keys() if all(match[i] is None or match[i] == x[i]
                                             for i in range(len(match))))

Unfortunately, this is quite slow compared to writing out the condition completely by hand ((match[0] is None or match[0] == x[0]) and (match[1] is None or match[1] == x[1])); for 4 dimensions it is about 10 times slower. This is a problem for me as I need to do this filtering quite often.
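
For reference, the hand-written 2-dimensional version I am comparing against looks roughly like this (my_filter_fn_2d is just an illustrative name):

def my_filter_fn_2d(entries: dict, match: tuple):
    # condition written out by hand for exactly 2 dimensions
    return (x for x in entries.keys()
            if (match[0] is None or match[0] == x[0])
            and (match[1] is None or match[1] == x[1]))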

The following code demonstrates the performance issue. It is supplied only to illustrate the problem and to make the tests reproducible. You can skip the code part; the results are below.

import random
import timeit


def access_variable_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in range(len(key)))):
            pass


def access_static_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if
                  (key[0] is None or x[0] == key[0])
                  and (key[1] is None or x[1] == key[1])
                  and (key[2] is None or x[2] == key[2])
                  and (key[3] is None or x[3] == key[3])):
            pass


def get_rand_or_none(start, stop):
    number = random.randint(start-1, stop)
    if number == start-1:
        number = None
    return number


entry_keys = set()
for h in range(100):
    entry_keys.add((get_rand_or_none(1, 200), get_rand_or_none(1, 10), get_rand_or_none(1, 4), get_rand_or_none(1, 7)))
all_entries = dict()
for l in range(13000):
    all_entries[(random.randint(1, 200), random.randint(1, 10), random.randint(1, 4), random.randint(1, 7))] = 1

variable_time = timeit.timeit("access_variable_length()", "from __main__ import access_variable_length", number=10)
static_time = timeit.timeit("access_static_length()", "from __main__ import access_static_length", number=10)

print("variable length time: {}".format(variable_time))
print("static length time: {}".format(static_time))

Results:

variable length time: 9.625867042849316
static length time: 1.043319165662158

I would like to avoid having to create three different functions my_filter_fn2, my_filter_fn3 and my_filter_fn4 to cover all possible dimensions of my dictionaries and then dispatch to static-dimension filtering. I am aware that filtering for variable dimensions will always be slower than filtering for fixed dimensions, but I was hoping it would not be almost 10 times slower. As I am not a Python expert, I was hoping there is a clever way to reformulate my variable-dimensions generator expression to get better performance.

What is the most efficient way to filter a huge dictionary in the way I described?

asked Jun 01 '17 by Chris

3 Answers

Thanks for the opportunity to think about tuples in sets and dictionaries. It's a very useful and powerful corner of Python.

Python is interpreted, so if you've come from a compiled language, one good rule of thumb is to avoid complex nested iterations where you can. If you're writing complicated for loops or comprehensions it's always worth wondering if there's a better way to do it.

List subscripts (stuff[i]) and range(len(stuff)) are inefficient and long-winded in Python, and rarely necessary. It's more efficient (and more natural) to iterate:

for item in stuff:
    do_something(item)
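
For example, the question's variable-length filter can drop the indexing entirely by pairing filter slots with key members via zip. A small sketch of that idea (just an illustration, not the approach used below):

def filter_keys_zip(entries, match):
    # pair each filter slot with the corresponding key member instead of indexing
    return {key for key in entries
            if all(m is None or m == k for m, k in zip(match, key))}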

The following code is fast because it uses some of the strengths of Python: comprehensions, dictionaries, sets and tuple unpacking.

There are iterations, but they're simple and shallow. There's only one if statement in the whole of the code, and that's executed only 4 times per filter operation. That also helps performance and makes the code easier to read.

An explanation of the method...

Each key from the original data:

{(1, 4, 5): 1}

is indexed by position and value:

{
    (0, 1): (1, 4, 5),
    (1, 4): (1, 4, 5),
    (2, 5): (1, 4, 5)
}

(Python numbers elements from zero.)

Indexes are collated into one big lookup dictionary composed of sets of tuples:

{
    (0, 1): {(1, 4, 5), (1, 6, 7), (1, 2), (1, 8), (1, 4, 2, 8), ...},
    (0, 2): {(2, 1), (2, 2), (2, 4, 1, 8), ...},
    (1, 4): {(1, 4, 5), (1, 4, 2, 8), (2, 4, 1, 8), ...},
    ...
}

Once this lookup is built (and it is built very efficiently), filtering is just set intersection and dictionary lookup, both of which are lightning-fast. Filtering takes microseconds even on a large dictionary.
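
For instance, with the filter (1, 4) only positions 0 and 1 constrain the result, so the whole operation boils down to a couple of dictionary lookups and one intersection. A tiny sketch with made-up sets (not the class below):

lookup = {
    (0, 1): {(1, 4, 5), (1, 6, 7), (1, 2), (1, 8), (1, 4, 2, 8)},
    (1, 4): {(1, 4, 5), (1, 4, 2, 8), (2, 4, 1, 8)},
}
# filter (1, 4): position 0 must be 1 and position 1 must be 4; None positions are skipped
matches = lookup[(0, 1)] & lookup[(1, 4)]
# matches == {(1, 4, 5), (1, 4, 2, 8)}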

The method handles data with tuples of arity 2, 3 or 4 (or any other) but arity_filtered() returns only keys with the same number of members as the filter tuple. So this class gives you the option of filtering all data together, or handling the different sizes of tuple separately, with little to choose between them as regards performance.

Timing results for the large random dataset (11,500 tuples) were 0.30 seconds to build the lookup and 0.007 seconds for 100 lookups.

from collections import defaultdict
import random
import timeit


class TupleFilter:
    def __init__(self, data):
        self.data = data
        self.lookup = self.build_lookup()

    def build_lookup(self):
        # map every (position, value) pair to the set of keys that contain it
        lookup = defaultdict(set)
        for data_item in self.data:
            for member_ref, data_key in tuple_index(data_item).items():
                lookup[member_ref].add(data_key)
        return lookup

    def filtered(self, tuple_filter):
        # initially unfiltered
        results = self.all_keys()
        # reduce filtered set
        for position, value in enumerate(tuple_filter):
            if value is not None:
                match_or_empty_set = self.lookup.get((position, value), set())
                results = results.intersection(match_or_empty_set)
        return results

    def arity_filtered(self, tuple_filter):
        tf_length = len(tuple_filter)
        return {match for match in self.filtered(tuple_filter) if tf_length == len(match)}

    def all_keys(self):
        return set(self.data.keys())


def tuple_index(item_key):
    # index a single key by (position, value); each entry maps back to the whole key
    member_refs = enumerate(item_key)
    return {(pos, val): item_key for pos, val in member_refs}


data = {
    (1, 2): 1,
    (1, 5): 2,
    (1, 5, 3): 2,
    (5, 2, 5, 2): 8
}

tests = {
     (1, 5): 2,
     (1, None, 3): 1,
     (1, None): 3,
     (None, 5): 2,
}

tf = TupleFilter(data)
for filter_tuple, expected_length in tests.items():
    result = tf.filtered(filter_tuple)
    print("Filter {0} => {1}".format(filter_tuple, result))
    assert len(result) == expected_length
# same arity filtering
filter_tuple = (1, None)
print('Not arity matched: {0} => {1}'
      .format(filter_tuple, tf.filtered(filter_tuple)))
print('Arity matched: {0} => {1}'
      .format(filter_tuple, tf.arity_filtered(filter_tuple)))
# check unfiltered results return original data set
assert tf.filtered((None, None)) == tf.all_keys()


$ python filter.py
Filter (1, 5) => {(1, 5), (1, 5, 3)}
Filter (1, None, 3) => {(1, 5, 3)}
Filter (1, None) => {(1, 2), (1, 5), (1, 5, 3)}
Filter (None, 5) => {(1, 5), (1, 5, 3)}
Not arity matched: (1, None) => {(1, 2), (1, 5), (1, 5, 3)}
Arity matched: (1, None) => {(1, 2), (1, 5)}
answered by Nick


I've made some modifications:

  • you don't need to use the dict.keys method to iterate over keys; iterating over the dict object itself gives us its keys,

  • created separate modules, which makes the code easier to read and modify:

    • preparation.py with helpers for generating test data:

      import random
      
      left_ends = [200, 10, 4, 7]
      
      
      def generate_all_entries(count):
          return {tuple(random.randint(1, num)
                        for num in left_ends): 1
                  for _ in range(count)}
      
      
      def generate_entry_keys(count):
          return [tuple(get_rand_or_none(1, num)
                        for num in left_ends)
                  for _ in range(count)]
      
      
      def get_rand_or_none(start, stop):
          number = random.randint(start - 1, stop)
          if number == start - 1:
              number = None
          return number
      
    • functions.py for tested functions,
    • main.py for benchmarks.
  • passing arguments to the functions instead of reading them from the global scope, so the static and variable length versions become

    def access_static_length(all_entries, entry_keys):
        for key in entry_keys:
            for k in (x
                      for x in all_entries
                      if (key[0] is None or x[0] == key[0])
                      and (key[1] is None or x[1] == key[1])
                      and (key[2] is None or x[2] == key[2])
                      and (key[3] is None or x[3] == key[3])):
                pass
    
    
    def access_variable_length(all_entries, entry_keys):
        for key in entry_keys:
            for k in (x
                      for x in all_entries
                      if all(key[i] is None or key[i] == x[i]
                             for i in range(len(key)))):
                pass
    
  • using min on the results of timeit.repeat instead of timeit.timeit to get the most representative results (more in this answer),

  • changing the entry_keys element count from 10 to 100 (inclusive) with step 10,

  • changing the all_entries element count from 10000 to 15000 (inclusive) with step 500.


But getting back to the point.

Improvements

  1. We can improve the filtering by skipping the checks at indexes where the filter key has None values:

    def access_variable_length_with_skipping_none(all_entries, entry_keys):
        for key in entry_keys:
            non_none_indexes = {i
                                for i, value in enumerate(key)
                                if value is not None}
            for k in (x
                      for x in all_entries.keys()
                      if all(key[i] == x[i]
                             for i in non_none_indexes)):
                pass
    
  2. The next suggestion is to use numpy:

    import numpy as np
    
    
    def access_variable_length_numpy(all_entries, entry_keys):
        keys_array = np.array(list(all_entries))
        for entry_key in entry_keys:
            non_none_indexes = [i
                                for i, value in enumerate(entry_key)
                                if value is not None]
            non_none_values = [value
                               for i, value in enumerate(entry_key)
                               if value is not None]
            # a row matches only if it equals the filter at every non-None position
            mask = (keys_array[:, non_none_indexes] == non_none_values).all(axis=1)
            indexes = np.where(mask)[0]
            for k in map(tuple, keys_array[indexes]):
                pass
    

Benchmarks

Contents of main.py:

import timeit
from itertools import product

number = 5
repeat = 10
for all_entries_count, entry_keys_count in product(range(10000, 15001, 500),
                                                   range(10, 101, 10)):
    print('all entries count: {}'.format(all_entries_count))
    print('entry keys count: {}'.format(entry_keys_count))
    preparation_part = ("from preparation import (generate_all_entries,\n"
                        "                         generate_entry_keys)\n"
                        "all_entries = generate_all_entries({all_entries_count})\n"
                        "entry_keys = generate_entry_keys({entry_keys_count})\n"
                        .format(all_entries_count=all_entries_count,
                                entry_keys_count=entry_keys_count))
    static_time = min(timeit.repeat(
        "access_static_length(all_entries, entry_keys)",
        preparation_part + "from functions import access_static_length",
        repeat=repeat,
        number=number))
    variable_time = min(timeit.repeat(
        "access_variable_length(all_entries, entry_keys)",
        preparation_part + "from functions import access_variable_length",
        repeat=repeat,
        number=number))
    variable_time_with_skipping_none = min(timeit.repeat(
        "access_variable_length_with_skipping_none(all_entries, entry_keys)",
        preparation_part +
        "from functions import access_variable_length_with_skipping_none",
        repeat=repeat,
        number=number))
    variable_time_numpy = min(timeit.repeat(
        "access_variable_length_numpy(all_entries, entry_keys)",
        preparation_part +
        "from functions import access_variable_length_numpy",
        repeat=repeat,
        number=number))

    print("static length time: {}".format(static_time))
    print("variable length time: {}".format(variable_time))
    print("variable length time with skipping `None` keys: {}"
          .format(variable_time_with_skipping_none))
    print("variable length time with numpy: {}"
          .format(variable_time_numpy))

which on my machine with Python 3.6.1 gives:

all entries count: 10000
entry keys count: 10
static length time: 0.06314293399918824
variable length time: 0.5234129569980723
variable length time with skipping `None` keys: 0.2890012050011137
variable length time with numpy: 0.22945181500108447
all entries count: 10000
entry keys count: 20
static length time: 0.12795891799760284
variable length time: 1.0610534609986644
variable length time with skipping `None` keys: 0.5744297259989253
variable length time with numpy: 0.5105678180007089
all entries count: 10000
entry keys count: 30
static length time: 0.19210158399801003
variable length time: 1.6491422000035527
variable length time with skipping `None` keys: 0.8566724129996146
variable length time with numpy: 0.7363859869983571
all entries count: 10000
entry keys count: 40
static length time: 0.2561357790000329
variable length time: 2.08878050599742
variable length time with skipping `None` keys: 1.1256247100027394
variable length time with numpy: 1.0066140279996034
all entries count: 10000
entry keys count: 50
static length time: 0.32130833200062625
variable length time: 2.6166040710013476
variable length time with skipping `None` keys: 1.4147321179989376
variable length time with numpy: 1.1700750320014777
all entries count: 10000
entry keys count: 60
static length time: 0.38276188999952865
variable length time: 3.153736616997776
variable length time with skipping `None` keys: 1.7147898039984284
variable length time with numpy: 1.4533947029995034
all entries count: 10000
entry keys count: 70
...
all entries count: 15000
entry keys count: 80
static length time: 0.7141444490007416
variable length time: 6.186657476999244
variable length time with skipping `None` keys: 3.376506028998847
variable length time with numpy: 3.1577993860009883
all entries count: 15000
entry keys count: 90
static length time: 0.8115685330012639
variable length time: 7.14327938399947
variable length time with skipping `None` keys: 3.7462387939995097
variable length time with numpy: 3.6140603050007485
all entries count: 15000
entry keys count: 100
static length time: 0.8950150890013902
variable length time: 7.829741768000531
variable length time with skipping `None` keys: 4.1662235900003
variable length time with numpy: 3.914334102999419

Summary

As we can see, the numpy version isn't as good as expected, and it doesn't seem to be numpy's fault.

If we remove the conversion of the filtered array records to tuples with map and just leave

for k in keys_array[indexes]:
    ...

then it becomes extremely fast (faster than the static length version), so the problem lies in the conversion from numpy.ndarray objects to tuple.

Skipping the None entries in the filter keys gives us about a 50% speed gain, so feel free to add it.
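
If needed, the None-skipping idea can also be packaged as a standalone filter that returns the matching keys rather than a benchmark loop. A rough sketch (the name filter_keys is mine):

def filter_keys(all_entries, key):
    # keep only the (index, value) pairs that actually constrain the match
    constraints = [(i, value)
                   for i, value in enumerate(key)
                   if value is not None]
    return {x for x in all_entries
            if all(x[i] == value for i, value in constraints)}

# e.g. filter_keys({(1, 2): 1, (1, 5): 2, (2, 5): 1}, (1, None)) == {(1, 2), (1, 5)}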

answered by Azat Ibrakov


I don't have a beautiful answer; this sort of optimisation often makes code harder to read. But if you just need more speed, here are two things you can do.

Firstly, we can straightforwardly eliminate a repeated computation from inside the loop. You say that all the keys in each dictionary have the same length, so you can compute that length once rather than repeatedly in the loop. This shaves off about 20% for me:

def access_variable_length():
    try:
        # all keys share the same length, so compute it once outside the loop
        length = len(next(iter(entry_keys)))
    except StopIteration:
        return
    r = list(range(length))
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in r)):
            pass

Not pretty, I agree. But we can make it much faster (and even uglier!) by building the fixed-length function using eval, like this:

def access_variable_length_new():
    try:
        length = len(next(iter(entry_keys)))
    except StopIteration:
        return
    # build the fixed-length condition for this dimensionality once, via eval
    func_l = ["(key[{0}] is None or x[{0}] == key[{0}])".format(i) for i in range(length)]
    func_s = "lambda x,key: " + " and ".join(func_l)
    func = eval(func_s)
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if func(x,key)):
            pass

For me, this is nearly as fast as the static version.
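
If you go this route, the generated predicate can also be cached per tuple length and reused, so eval runs only once per dimensionality. A rough sketch (the names here are mine, not part of the code above):

_predicate_cache = {}


def make_predicate(length):
    # build (and cache) a fixed-length predicate via eval, one per dimensionality
    if length not in _predicate_cache:
        conditions = " and ".join(
            "(key[{0}] is None or x[{0}] == key[{0}])".format(i)
            for i in range(length))
        _predicate_cache[length] = eval("lambda x, key: " + conditions)
    return _predicate_cache[length]


def filter_keys(entries, match):
    predicate = make_predicate(len(match))
    return {x for x in entries if predicate(x, match)}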

answered by strubbly