
Inconsistent sorting with sort()

I have the following function to count words from a string and extract the top "n":

Function

from collections import Counter

def count_words(s, n):
    """Return the n most frequently occurring words in s."""

    #Split words into list
    wordlist = s.split()

    #Count words
    counts = Counter(wordlist)

    #Get top n words
    top_n = counts.most_common(n)

    #Sort by count (descending), if tie by word (alphabetically)
    top_n.sort(key=lambda x: (-x[1], x[0]))

    return top_n

So it sorts by number of occurrences and, if tied, alphabetically. Consider the following examples:

print count_words("cat bat mat cat cat mat mat mat bat bat cat", 3)

works (shows [('cat', 4), ('mat', 4), ('bat', 3)])

print count_words("betty bought a bit of butter but the butter was bitter", 3)

does not work (shows [('butter', 2), ('a', 1), ('bitter', 1)], but should have betty instead of bitter, as they are tied and be... comes before bi...)

print count_words("betty bought a bit of butter but the butter was bitter", 6)

works (shows [('butter', 2), ('a', 1), ('betty', 1), ('bitter', 1), ('but', 1), ('of', 1)] with betty before bitter as intended)

What could cause that (word length, maybe?) and how could I fix it?

PrimuS asked Dec 08 '22 20:12


2 Answers

The problem is not the sort call but most_common. The Counter is implemented as a hash table, so the order of its items has nothing to do with your sort key. When you ask for most_common(n), it returns the n most common words by count alone, and if there are ties at the cutoff it just decides arbitrarily which ones to return! Your later sort can only reorder the items that survived that cut.

The simplest way to solve this is to avoid most_common altogether: sort all the items with your key, then slice off the first n:

top_n = sorted(counts.items(), key=lambda x: (-x[1], x[0]))[:n]
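
To make the difference concrete, here is a minimal sketch that puts the two orderings side by side: truncating with most_common(n) and then sorting, versus sorting everything and then slicing. The function names are made up for this illustration and are not part of the question or of any library.

```python
from collections import Counter

def top_n_truncate_then_sort(s, n):
    # Hypothetical name: the question's approach.
    # most_common(n) cuts the list using the count only, then we sort the survivors.
    top = Counter(s.split()).most_common(n)
    top.sort(key=lambda kv: (-kv[1], kv[0]))
    return top

def top_n_sort_then_truncate(s, n):
    # Hypothetical name: the approach from this answer.
    # Sort every (word, count) pair by the full key, then keep the first n.
    counts = Counter(s.split())
    return sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))[:n]

sentence = "betty bought a bit of butter but the butter was bitter"

truncated_first = top_n_truncate_then_sort(sentence, 3)
sorted_first = top_n_sort_then_truncate(sentence, 3)

# Which tied words survive the cut here depends on the Python version.
print(truncated_first)

# [('butter', 2), ('a', 1), ('betty', 1)]
print(sorted_first)
```

The second function gives the same answer regardless of how the Counter happens to order its items, because the cut is made only after the alphabetical tie-break has been applied.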

Bakuriu answered Dec 11 '22 10:12


You are asking for the top 3, and thus you cut off the data before you could have picked items in your specific sorting order.

Rather than have most_common() pre-sort and then re-sort, use heapq to select by your custom criteria (provided n is smaller than the number of actual buckets):

import heapq
from collections import Counter

def count_words(s, n):
    """Return the n most frequently occurring words in s."""
    counts = Counter(s.split())
    key = lambda kv: (-kv[1], kv[0])
    if n >= len(counts):
        return sorted(counts.items(), key=key)
    return heapq.nsmallest(n, counts.items(), key=key)

On Python 2, you probably want to use iteritems() rather than items() for the above calls.

This re-creates the Counter.most_common() method, but with the updated key. Like the original, using heapq keeps this bounded by O(N log K) rather than O(N log N) (with N the number of buckets, and K the number of top elements you want to see).
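
As a sanity check on that claim, the sketch below compares heapq.nsmallest against a full sort followed by a slice, using the same key, over a range of n values. The sample data, the seed, and the names rank_key and all_match are assumptions made up for this illustration.

```python
import heapq
import random
import string
from collections import Counter

def rank_key(kv):
    # Highest count first; ties broken alphabetically by word.
    return (-kv[1], kv[0])

# Made-up sample data: 500 random two-letter "words", so many counts tie.
random.seed(0)
words = ["".join(random.choice(string.ascii_lowercase) for _ in range(2))
         for _ in range(500)]
counts = Counter(words)

all_match = True
for n in (0, 1, 3, 10, len(counts), len(counts) + 5):
    via_heap = heapq.nsmallest(n, counts.items(), key=rank_key)
    via_sort = sorted(counts.items(), key=rank_key)[:n]
    if via_heap != via_sort:
        all_match = False

print(all_match)
```

Because every word is unique, every key is unique too, so both approaches must produce identical lists; the only difference is how much work is done to get there.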

Demo:

>>> count_words("cat bat mat cat cat mat mat mat bat bat cat", 3)
[('cat', 4), ('mat', 4), ('bat', 3)]
>>> count_words("betty bought a bit of butter but the butter was bitter", 3)
[('butter', 2), ('a', 1), ('betty', 1)]
>>> count_words("betty bought a bit of butter but the butter was bitter", 6)
[('butter', 2), ('a', 1), ('betty', 1), ('bit', 1), ('bitter', 1), ('bought', 1)]

And a quick performance comparison (on Python 3.6.0b1):

>>> from collections import Counter
>>> from heapq import nsmallest
>>> from random import choice, randrange
>>> from timeit import timeit
>>> from string import ascii_letters
>>> sentence = ' '.join([''.join([choice(ascii_letters) for _ in range(randrange(3, 15))]) for _ in range(1000)])
>>> counts = Counter(sentence)  # count letters
>>> len(counts)
53
>>> key = lambda kv: (-kv[1], kv[0])
>>> timeit('sorted(counts.items(), key=key)[:3]', 'from __main__ import counts, key', number=100000)
2.119404911005404
>>> timeit('nsmallest(3, counts.items(), key=key)', 'from __main__ import counts, nsmallest, key', number=100000)
1.9657367869949667
>>> counts = Counter(sentence.split())  # count words
>>> len(counts)
1000
>>> timeit('sorted(counts.items(), key=key)[:3]', 'from __main__ import counts, key', number=10000)  # note, 10 times fewer
6.689963405995513
>>> timeit('nsmallest(3, counts.items(), key=key)', 'from __main__ import counts, nsmallest, key', number=10000)
2.902360848005628

Martijn Pieters answered Dec 11 '22 08:12