 

How to avoid using _siftup or _siftdown in heapq

Tags:

python

heap

I have no idea how to solve the following problem efficiently without using _siftup or _siftdown:

How to restore the heap invariant, when one element is out-of-order?

In other words, update old_value in the heap to new_value and keep the heap valid. You can assume there is only one old_value in the heap. The function signature is:

def update_value_in_heap(heap, old_value, new_value):

Here is my real scenario; read it if you are interested.

  • You can imagine it as a small autocomplete system. I need to count the frequency of words and maintain the top k most frequent words, which must be ready for output at any moment. So I use a heap here. When a word's count increases, I need to update it if it is in the heap.

  • All the words and counts are stored in the trie's leaf nodes, and the heaps
    are stored in the trie's internal nodes. If you are concerned about words
    outside the heap, don't worry: I can get them from the trie's leaf nodes.

  • When the user types a word, the system first reads from the heap and then
    updates it. For better performance, we could reduce the update frequency by updating in batches.

So how to update the heap, when one particular word count increase?

Here is a simple example using _siftup and _siftdown (not my scenario):

>>> from heapq import _siftup, _siftdown, heapify, heappop

>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 22              # increase the 8 to 22
>>> i = data.index(old)
>>> data[i] = new
>>> _siftup(data, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 5, 7, 10, 18, 19, 22, 37]

>>> data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
>>> heapify(data)
>>> old, new = 8, 4              # decrease the 8 to 4
>>> i = data.index(old)
>>> data[i] = new
>>> _siftdown(data, 0, i)
>>> [heappop(data) for i in range(len(data))]
[1, 2, 3, 4, 5, 7, 10, 18, 19, 37]

It costs O(n) to find the index and O(log n) to update. heapify is another solution, but it is less efficient than _siftup or _siftdown.

But _siftup and _siftdown are protected members of heapq, so accessing them from outside is not recommended.

So is there a better, more efficient way to solve this problem? What is the best practice for this situation?

Thanks for reading; I really appreciate your help. :)

I have already read "heapq python - how to modify values for which heap is sorted", but it does not answer my problem.

asked Mar 27 '19 09:03 by recnac



3 Answers

The answer from @cglacet is completely wrong, but it looks very legitimate. The code snippet he provided is completely broken! It is also very hard to read. _siftup() is called n//2 times in heapify(), so heapify() cannot be faster than a single _siftup() call by itself.
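To make the n//2 figure concrete, here is a sketch of bottom-up heapify with the same loop structure as the pure-Python heapify() in heapq.py (`for i in reversed(range(n//2)): _siftup(x, i)`), instrumented to count the sift calls. The sift helper is a textbook swap-based variant written for this illustration, not CPython's exact _siftup, so the final arrangement may differ from heapq.heapify() even though both satisfy the heap invariant.

```python
def sift_toward_leaves(heap, pos):
    """Move heap[pos] down while a child is smaller (textbook swap version)."""
    n = len(heap)
    while True:
        child = 2 * pos + 1                  # left child
        if child >= n:
            return                           # pos is a leaf
        if child + 1 < n and heap[child + 1] < heap[child]:
            child += 1                       # pick the smaller child
        if heap[child] < heap[pos]:
            heap[pos], heap[child] = heap[child], heap[pos]
            pos = child
        else:
            return


def heapify_counting(x):
    """Bottom-up heapify; returns how many times the sift helper ran."""
    calls = 0
    for i in reversed(range(len(x) // 2)):   # every non-leaf, last one first
        sift_toward_leaves(x, i)
        calls += 1
    return calls


data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
print(heapify_counting(data))   # 5, i.e. len(data) // 2
```

One heapify() therefore does at least as many sift steps as len(heap) // 2 separate sift calls, which is the point being made about a single update.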

To answer the original question, there is no better way. If you are concerned about the methods being private, create your own functions that do the same thing.

The only thing I agree with is that if you don't need to read from the heap for a long time, it might be beneficial to heapify() it lazily once you need it. The question then is whether you should use a heap for that at all.

Let's go over the problems with his snippet:

The heapify() function gets called multiple times for the "update" run. The chain of errors that leads to this is as follows:

  • He passes heap_fix, but the constructor expects heap, and the same is true for sort_fix versus sort
  • Because self.sort is always False, self.heap is always True
  • He redefines __getitem__() and __setitem__(), which are called every time _siftup() or _siftdown() assigns or reads something (note: these two aren't implemented in C, so they go through __getitem__() and __setitem__())
  • If self.heap is True and __getitem__() and __setitem__() are being called, the _repair() function is called every time _siftup() or _siftdown() swaps elements. But the call to heapify() is done in C, so __getitem__() does not get called, and it doesn't end up in an infinite loop
  • He redefines self.sort, so calling it, as he tries to do, would fail
  • He reads once, but updates an item nb_updates times, not 1:1 as he claims
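The sort/heap bullets can be reproduced with a trimmed copy of that constructor (renamed TrimmedHeap here; every other method is omitted). Assigning self.sort creates an instance attribute that shadows the list.sort method, and heap ends up True whenever sort is False:

```python
class TrimmedHeap(list):
    """Trimmed copy of the constructor under discussion (all other methods omitted)."""

    def __init__(self, values, sort=False, heap=False):
        super().__init__(values)
        self.sort = sort               # instance attribute: shadows the list.sort method
        self.heap = heap or not sort   # True whenever sort is False


h = TrimmedHeap([3, 1, 2])
print(h.sort, h.heap)                  # False True

sort_error = None
try:
    TrimmedHeap([3, 1, 2], sort=True).sort()   # finds the bool True, not list.sort
except TypeError as exc:
    sort_error = exc
print(sort_error)                      # 'bool' object is not callable
```

Calling list.sort(h) explicitly would still work, but self.sort() as written in _repair() cannot.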

I fixed the example and tried to verify it as best I could, but we all make mistakes. Feel free to check it yourself.

Code

import time
import random

from heapq import _siftup, _siftdown, heapify, heappop

class UpdateHeap(list):
    def __init__(self, values):
        super().__init__(values)
        heapify(self)

    def update(self, index, value):
        old, self[index] = self[index], value
        if value > old:
            _siftup(self, index)
        else:
            _siftdown(self, 0, index)

    def pop(self):
        return heappop(self)

class SlowHeap(list):
    def __init__(self, values):
        super().__init__(values)
        heapify(self)
        self._broken = False
        
    # Solution 2 and 3) repair using sort/heapify in a lazy way:
    def update(self, index, value):
        super().__setitem__(index, value)
        self._broken = True
    
    def __getitem__(self, index):
        if self._broken:
            self._repair()
            self._broken = False
        return super().__getitem__(index)

    def _repair(self):
        ...

    def pop(self):
        if self._broken:
            self._repair()
            self._broken = False   # repaired once; don't repeat on the next pop
        return heappop(self)

class HeapifyHeap(SlowHeap):

    def _repair(self):
        heapify(self)


class SortHeap(SlowHeap):

    def _repair(self):
        self.sort()

def rand_update(heap):
    index = random.randint(0, len(heap)-1)
    new_value = random.randint(max_int+1, max_int*2)
    heap.update(index, new_value)
    
def rand_updates(update_count, heap):
    for i in range(update_count):
        rand_update(heap)
        heap[0]
        
def verify(heap):
    last = None
    while heap:
        item = heap.pop()
        if last is not None and item < last:
            raise RuntimeError(f"{item} was smaller than last {last}")
        last = item

def run_perf_test(update_count, data, heap_class):
    test_heap = heap_class(data)
    t0 = time.time()
    rand_updates(update_count, test_heap)
    perf = (time.time() - t0)*1e3
    verify(test_heap)
    return perf


results = []
max_int = 500
update_count = 100

for i in range(2, 7):
    test_size = 10**i
    test_data = [random.randint(0, max_int) for _ in range(test_size)]

    perf = run_perf_test(update_count, test_data, UpdateHeap)
    results.append((test_size, "update", perf))
    
    perf = run_perf_test(update_count, test_data, HeapifyHeap)
    results.append((test_size, "heapify", perf))

    perf = run_perf_test(update_count, test_data, SortHeap)
    results.append((test_size, "sort", perf))

import pandas as pd
import seaborn as sns

dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)

sns.lineplot(
    data=dtf, 
    x="heap size", 
    y="duration (ms)", 
    hue="method",
)

Results

As you can see, the "update" method using _siftdown() and _siftup() is asymptotically faster.

You should know what your code does and how long it will take to run. If in doubt, you should check. @cglacet checked how long the execution takes, but he didn't question how long it should take. If he had, he would have found that the two don't match up. And others fell for it.

    heap size   method  duration (ms)
0         100   update       0.219107
1         100  heapify       0.412703
2         100     sort       0.242710
3        1000   update       0.198841
4        1000  heapify       2.947330
5        1000     sort       0.605345
6       10000   update       0.203848
7       10000  heapify      32.759190
8       10000     sort       4.621506
9      100000   update       0.348568
10     100000  heapify     327.646971
11     100000     sort      49.481153
12    1000000   update       0.256062
13    1000000  heapify    3475.244761
14    1000000     sort    1106.570005


answered Oct 19 '22 18:10 by Tojik



TL;DR Use heapify.

One important thing to keep in mind is that theoretical complexity and performance are two different things (even though they are related). In other words, implementation matters too. Asymptotic complexities give you upper bounds that you can see as guarantees; for example, an O(n) algorithm ensures that, in the worst case, you will execute a number of instructions that grows at most linearly with the input size. There are two important things here:

  1. constants are ignored, but constants matter in real life;
  2. the worst case scenario is dependent on the algorithm you consider, not only on the input.

Depending on the topic/problem you consider, the first point can be very important. In some domains, constants hidden in asymptotic complexities are so big that you can't even build inputs that are bigger than the constants (or that input wouldn't be realistic to consider). That's not the case here, but that's something you always have to keep in mind.

Given these two observations, you can't really say that implementation B is faster than A because A is derived from an O(n) algorithm and B from an O(log n) algorithm. Even if that's a good argument to start with in general, it's not always sufficient. Theoretical complexities are especially good for comparing algorithms when all inputs are equally likely to happen; in other words, when your algorithms are very generic.

If you know what your use cases and inputs will be, you can just test performance directly. Using both the tests and the asymptotic complexity will give you a good idea of how your algorithm will perform (in both extreme cases and arbitrary practical cases).

That being said, let's run some performance tests on the following class, which implements three different strategies (there are actually four strategies here, but Invalidate and Reinsert doesn't seem right in your case, as you'd invalidate each item as many times as you see a given word). I'll include most of my code so you can double-check that I haven't messed up (you can even check the complete notebook):

from heapq import _siftup, _siftdown, heapify, heappop

class Heap(list):
  def __init__(self, values, sort=False, heap=False):
    super().__init__(values)
    heapify(self)
    self._broken = False
    self.sort = sort
    self.heap = heap or not sort

  # Solution 1) repair using the knowledge we have after every update:        
  def update(self, key, value):
    old, self[key] = self[key], value
    if value > old:
        _siftup(self, key)
    else:
        _siftdown(self, 0, key)
    
  # Solution 2 and 3) repair using sort/heapify in a lazy way:
  def __setitem__(self, key, value):
    super().__setitem__(key, value)
    self._broken = True
    
  def __getitem__(self, key):
    if self._broken:
        self._repair()
        self._broken = False
    return super().__getitem__(key)

  def _repair(self):  
    if self.sort:
        self.sort()
    elif self.heap:
        heapify(self)

  # … you'll also need to delegate all other heap functions, for example:
  def pop(self):
    self._repair()
    return heappop(self)

We can first check that all three methods work:

data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]

heap = Heap(data[:])
heap.update(8, 22)
heap.update(7, 4)
print(heap)

heap = Heap(data[:], sort_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)

heap = Heap(data[:], heap_fix=True)
heap[8] = 22
heap[7] = 4
print(heap)

Then we can run some performance tests using the following functions:

import time
import random

def rand_update(heap, lazzy_fix=False, **kwargs):
    index = random.randint(0, len(heap)-1)
    new_value = random.randint(max_int+1, max_int*2)
    if lazzy_fix:
        heap[index] = new_value
    else:
        heap.update(index, new_value)
    
def rand_updates(n, heap, lazzy_fix=False, **kwargs):
    for _ in range(n):
        rand_update(heap, lazzy_fix)
        
def run_perf_test(n, data, **kwargs):
    test_heap = Heap(data[:], **kwargs)
    t0 = time.time()
    rand_updates(n, test_heap, **kwargs)
    test_heap[0]
    return (time.time() - t0)*1e3

results = []
max_int = 500
nb_updates = 1

for i in range(3, 7):
    test_size = 10**i
    test_data = [random.randint(0, max_int) for _ in range(test_size)]

    perf = run_perf_test(nb_updates, test_data)
    results.append((test_size, "update", perf))
    
    perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, heap_fix=True)
    results.append((test_size, "heapify", perf))

    perf = run_perf_test(nb_updates, test_data, lazzy_fix=True, sort_fix=True)
    results.append((test_size, "sort", perf))

The results are the following:

import pandas as pd
import seaborn as sns

dtf = pd.DataFrame(results, columns=["heap size", "method", "duration (ms)"])
print(dtf)

sns.lineplot(
    data=dtf, 
    x="heap size", 
    y="duration (ms)", 
    hue="method",
)

From these tests, heapify seems like the most reasonable choice: it has a decent worst-case complexity, O(n), and performs better in practice. On the other hand, it's probably a good idea to investigate other options (like a data structure dedicated to this particular problem; for example, dropping words into bins and then moving them from one bin to the next looks like a possible avenue to investigate).

Important remark: this scenario (updating vs. reading ratio of 1:1) is unfavorable to both the heapify and sort solutions. So if you manage to have a k:1 ratio, this conclusion will be even clearer (you can replace nb_updates = 1 with nb_updates = k in the above code).
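A back-of-the-envelope way to see why the update/read ratio matters (not measured; c₁, c₂, c₃ are unknown per-operation constants, and in CPython they differ a lot because _siftup runs in Python while heapify runs in C). With k updates between two reads on a heap of size n, eager repair costs a sift per update, while lazy repair costs a cheap write per update plus one heapify per read:

```latex
T_{\text{eager}}(k) \approx c_1\, k \log_2 n
\qquad
T_{\text{lazy}}(k) \approx c_2\, k + c_3\, n
\qquad\Longrightarrow\qquad
T_{\text{lazy}} < T_{\text{eager}} \iff k\,\bigl(c_1 \log_2 n - c_2\bigr) > c_3\, n
```

If the constants were equal and c₂ negligible, lazy repair would only pay off once k exceeds roughly n / log₂ n; for n = 10⁶ (log₂ n ≈ 19.9) that is about 5 × 10⁴ updates per read.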

Dataframe details:

    heap size   method  duration in ms
0        1000   update        0.435114
1        1000  heapify        0.073195
2        1000     sort        0.101089
3       10000   update        1.668930
4       10000  heapify        0.480175
5       10000     sort        1.151085
6      100000   update       13.194084
7      100000  heapify        4.875898
8      100000     sort       11.922121
9     1000000   update      153.587103
10    1000000  heapify       51.237106
11    1000000     sort      145.306110
answered Oct 19 '22 18:10 by cglacet



Dealing with private functions

But _siftup and _siftdown are protected member in heapq, so they are not recommended to access from outside.

The code snippets are short, so you could just include them in your own code after renaming them as public functions:

def siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

def siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now.  Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    siftdown(heap, startpos, pos)
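With those helpers, the question's update_value_in_heap(heap, old_value, new_value) takes only a few lines. To keep this sketch runnable on its own, the two helpers are repeated in a slightly simplified form (same names and arguments as above; this siftup stops as soon as no child is smaller, instead of CPython's descend-to-a-leaf-then-bubble-up strategy). The heap.index() call remains the O(n) part.

```python
import heapq


def siftdown(heap, startpos, pos):
    """Move heap[pos] toward the root (not past startpos) while it is smaller than its parent."""
    newitem = heap[pos]
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        if newitem < heap[parentpos]:
            heap[pos] = heap[parentpos]
            pos = parentpos
        else:
            break
    heap[pos] = newitem


def siftup(heap, pos):
    """Move heap[pos] toward the leaves while its smaller child is smaller than it."""
    endpos = len(heap)
    newitem = heap[pos]
    childpos = 2 * pos + 1
    while childpos < endpos:
        rightpos = childpos + 1
        if rightpos < endpos and heap[rightpos] < heap[childpos]:
            childpos = rightpos
        if not heap[childpos] < newitem:
            break
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2 * pos + 1
    heap[pos] = newitem


def update_value_in_heap(heap, old_value, new_value):
    """Replace the (assumed unique) old_value and restore the min-heap invariant."""
    i = heap.index(old_value)      # O(n) search: the part a heap cannot speed up
    heap[i] = new_value
    if new_value > old_value:
        siftup(heap, i)            # a larger value may need to move toward the leaves
    else:
        siftdown(heap, 0, i)       # a smaller value may need to move toward the root


data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
heapq.heapify(data)
update_value_in_heap(data, 8, 22)
print([heapq.heappop(data) for _ in range(len(data))])
# [1, 2, 3, 5, 7, 10, 18, 19, 22, 37]
```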

Maintaining the heap invariant

How to restore the heap invariant, when one element is out-of-order?

Using the high-level heap API, you could make a series of count updates and then run heapify() before doing any more heap operations. This might or might not be efficient enough for your needs.
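A minimal sketch of that batch-then-heapify approach, assuming (hypothetically) that each heap entry is a mutable [negated_count, word] list so counts can be bumped in place. apply_batch() is an illustrative name, and rebuilding the word-to-entry dict on every call is a simplification; in the question's trie, that lookup could come from the leaf nodes instead.

```python
import heapq


def apply_batch(heap, words):
    """Bump the count of each word in `words`, then restore the invariant once.

    heap holds [neg_count, word] lists, so the most frequent word sits at heap[0].
    """
    entry_for = {entry[1]: entry for entry in heap}  # O(n), same order as heapify itself
    for word in words:
        entry = entry_for.get(word)
        if entry is not None:       # words outside the heap are ignored in this sketch
            entry[0] -= 1           # one more occurrence -> more negative key
    heapq.heapify(heap)             # a single O(n) repair for the whole batch


heap = [[-3, "three"], [-2, "one"], [-2, "two"]]
heapq.heapify(heap)
apply_batch(heap, ["two", "two"])
print(heap[0])                      # [-4, 'two']
```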

That said, the heapify() function is very fast. Interestingly, the list.sort() method is even more optimized and might beat heapify() for some types of inputs.
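The list.sort() option works because an ascending list already satisfies the heap invariant (every parent is at most its children), so a sorted list can be passed straight to heappop(). A small check, using a hypothetical is_min_heap() helper:

```python
import heapq


def is_min_heap(a):
    """True if every element is >= its parent, i.e. the heapq invariant holds."""
    return all(a[(i - 1) // 2] <= a[i] for i in range(1, len(a)))


data = [10, 5, 18, 2, 37, 3, 8, 7, 19, 1]
by_sort = sorted(data)          # ascending order implies parent <= child
by_heapify = data[:]
heapq.heapify(by_heapify)       # guarantees only the weaker parent/child ordering

print(is_min_heap(by_sort), is_min_heap(by_heapify))   # True True
```

The reverse does not hold: heapify() only establishes the parent/child ordering, so its output is generally not sorted.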

Better data structure

need to count the frequency of words, and maintain the top k max-count words, which prepare to output at any moment. So I use heap here. When one word count++, I need update it if it is in heap.

Consider using a different data structure than a heap. At first, heaps seem well suited to the task; however, finding an arbitrary entry in a heap is slow (a linear search), even though we can then update it quickly with siftup/siftdown.

Instead, consider keeping a dictionary that maps each word to its position in parallel lists of words and counts. Keeping those lists sorted requires only a binary search and a single position swap when a count is incremented:

from bisect import bisect_left

word2pos = {}
words = []    # ordered by descending frequency
counts = []   # negated to put most common first

def tally(word):
    if word not in word2pos:
        word2pos[word] = len(word2pos)
        counts.append(-1)
        words.append(word)
    else:
        pos = word2pos[word]
        count = counts[pos]
        swappos = bisect_left(counts, count, hi=pos)
        words[pos] = swapword = words[swappos]
        counts[pos] = counts[swappos]
        word2pos[swapword] = pos
        words[swappos] = word
        counts[swappos] = count - 1
        word2pos[word] = swappos

def topwords(n):
    return [(-counts[i], words[i]) for i in range(min(n, len(words)))]

Built in solution

There is another "out-of-the-box" solution that might meet your needs. Just use collections.Counter():

>>> from collections import Counter
>>> c = Counter()
>>> for word in 'one two one three two three three'.split():
...     c[word] += 1
...
>>> c.most_common(2)
[('three', 3), ('one', 2)]
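For the batched updates mentioned in the question, Counter.update() accepts an iterable of words, and most_common(k) recomputes the top k on demand (CPython implements it with heapq.nlargest(), roughly O(n log k) per call), so no heap has to be maintained between reads. A minimal sketch; record_batch() and top_k() are hypothetical names:

```python
from collections import Counter

counts = Counter()


def record_batch(words):
    """Count a batch of typed words in one call (adds 1 per occurrence)."""
    counts.update(words)


def top_k(k):
    """The k most frequent (word, count) pairs, computed when asked for."""
    return counts.most_common(k)


record_batch("one two one three".split())
record_batch(["two", "three", "three"])
print(top_k(2))   # [('three', 3), ('one', 2)]
```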

Binary tree or Sorted Container Solution

Another way to go is to use a binary tree or sorted container. These have O(log n) insertion and removal. And they stand ready to iterate in forward or reverse order without additional computation effort.

Here is a solution using Grant Jenks's wonderful Sorted Containers package:

from sortedcontainers import SortedSet
from dataclasses import dataclass, field
from itertools import islice

@dataclass(order=True, unsafe_hash=True, slots=True)   # slots=True requires Python 3.10+
class Entry:
    count: int = field(hash=False)
    word: str

w2e = {}                  # type: Dict[str, Entry]
ss = SortedSet()          # type: Set[Entry]   

def tally(word):
    if word not in w2e:
        entry = w2e[word] = Entry(1, word)
        ss.add(entry)
    else:
        entry = w2e[word]
        ss.remove(entry)
        entry.count += 1
        ss.add(entry)

def topwords(n):
    return list(islice(reversed(ss), n))
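The key step above is remove, mutate, re-add: an entry must leave the sorted structure before its count (the sort key) changes. If the third-party package isn't available, a stdlib-only stand-in can use bisect on a plain sorted list of (count, word) tuples. This swaps in a different container: insort() finds the slot in O(log n), but shifting list elements makes each insertion O(n), unlike SortedSet.

```python
from bisect import bisect_left, insort
from itertools import islice

counts = {}    # word -> count
entries = []   # (count, word) tuples kept in ascending order


def tally(word):
    old = counts.get(word, 0)
    if old:
        # Remove the stale entry *before* its sort key changes.
        del entries[bisect_left(entries, (old, word))]
    counts[word] = old + 1
    insort(entries, (old + 1, word))   # O(log n) search, O(n) element shift


def topwords(n):
    return list(islice(reversed(entries), n))


for w in "one two one three two three three".split():
    tally(w)
print(topwords(2))   # [(3, 'three'), (2, 'two')]
```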
answered Oct 19 '22 18:10 by Raymond Hettinger
