
n-largest elements in a sequence (need to retain duplicates)

I need to find the n largest elements in a list of tuples. Here is an example for top 3 elements.

# I have a list of tuples of the form (category-1, category-2, value)
# For each category-1, ***values are already sorted descending by default***
# The list can potentially be approximately a million elements long.
lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4',  8), ('a', 'x5', 8), ('a', 'x6', 7),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), 
       ('b', 'x4',  7), ('b', 'x5', 6), ('b', 'x6', 5)]

# This is what I need. 
# A list of tuple with top-3 largest values for each category-1
ans = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4', 8), ('a', 'x5', 8),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

I tried using heapq.nlargest. However, it returns exactly 3 elements and counts duplicates toward that limit, instead of returning every element that has one of the 3 largest distinct values. For example,

import heapq
heapq.nlargest(3, [10, 10, 10, 9, 8, 8, 7, 6])
# returns
[10, 10, 10]
# I need
[10, 10, 10, 9, 8, 8]
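
For reference, the desired behaviour on a flat list can be sketched with heapq itself: take the n largest distinct values first, then keep every element whose value is among them. This is only a rough sketch; the helper name nlargest_with_ties is made up for illustration and is not part of heapq.

```python
import heapq

def nlargest_with_ties(n, values):
    """Return every element whose value is among the n largest distinct values.

    Hypothetical helper, not part of heapq. `values` must be a list (it is
    traversed twice) and the input order is preserved.
    """
    # The n largest *distinct* values, e.g. {10, 9, 8} for n = 3.
    top = set(heapq.nlargest(n, set(values)))
    return [v for v in values if v in top]

print(nlargest_with_ties(3, [10, 10, 10, 9, 8, 8, 7, 6]))
# [10, 10, 10, 9, 8, 8]
```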

I can only think of a brute force approach. This is what I have and it works.

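# Brute force: walk the list once, counting distinct values per category.
res, prev_t, count = [lot[0]], lot[0], 1
for t in lot[1:]:
    if t[0] == prev_t[0]:
        # Same category: a new value starts the next rank.
        if t[2] != prev_t[2]:
            count += 1
        if count <= 3:
            res.append(t)
    else:
        # New category: its first tuple always holds the largest value.
        count = 1
        res.append(t)
    prev_t = t

print(res)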

Any other ideas on how I can implement this?

EDIT: timeit results for a list of 1 million elements show that mhyfritz's solution runs in about a third of the time of the brute-force approach. I didn't want to make the question too long, so I added more details in my answer.

Praveen Gollakota asked Jul 12 '11 19:07


3 Answers

I take it from your code snippet that lot is grouped w.r.t. category-1. The following should work then:

from itertools import groupby, islice
from operator import itemgetter

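ans = []
# Outer groupby: one group per category-1 (requires lot to be grouped by it).
for x, g1 in groupby(lot, itemgetter(0)):
    # Inner groupby: runs of equal values; islice keeps the first 3 runs.
    for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3):
        ans.extend(list(g2))

print(ans)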
# [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8),
#  ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]
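
The same idea can be wrapped in a generator so that n is a parameter and the result is produced lazily. This is a sketch and not part of the original snippet; the function name top_n_per_category and its keyword arguments are assumptions made for illustration. It relies on the same precondition: rows grouped by category and sorted descending by value within each category.

```python
from itertools import groupby, islice
from operator import itemgetter

def top_n_per_category(rows, n, category=itemgetter(0), value=itemgetter(2)):
    """Yield the rows that carry the first n distinct values of each category.

    Assumes rows are grouped by category and sorted descending by value
    within each category, as in the question.
    """
    for _, by_category in groupby(rows, category):
        # Each inner group is one run of equal values; keep the first n runs.
        for _, ties in islice(groupby(by_category, value), n):
            for row in ties:
                yield row

lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9),
       ('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8),
       ('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]

print(list(top_n_per_category(lot, 3)))
```

Because it is a generator, the result can be consumed row by row without building a second million-element list.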
mhyfritz answered Nov 17 '22 14:11


If your input data is already sorted that way, then your solution is very probably a little better than the heapq-based one.

Your algorithm's complexity is O(n), while the heapq-based one is conceptually O(n * log(3)) and will probably need more passes over the data to arrange it properly.
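
To make the comparison concrete, here is one possible shape of a heapq-based version that does not rely on the input being sorted. It is a sketch under assumptions, not a reference implementation: the name top_n_per_category_unsorted is hypothetical, and rows are assumed to be (category, label, value) tuples. Note the two passes over the data plus the heap step in between.

```python
import heapq
from collections import defaultdict

def top_n_per_category_unsorted(rows, n):
    """Keep rows whose value is among the n largest distinct values of their category.

    Hypothetical helper for input in arbitrary order; the original row
    order is preserved in the result.
    """
    # Pass 1: collect the distinct values seen for each category.
    distinct = defaultdict(set)
    for category, _, value in rows:
        distinct[category].add(value)
    # Heap step: pick the n largest distinct values per category.
    top = {c: set(heapq.nlargest(n, vals)) for c, vals in distinct.items()}
    # Pass 2: filter the rows against those per-category value sets.
    return [row for row in rows if row[2] in top[row[0]]]

shuffled = [('a', 'x6', 7), ('a', 'x1', 10), ('b', 'x2', 9), ('a', 'x3', 9),
            ('a', 'x2', 9), ('b', 'x5', 6), ('a', 'x4', 8), ('b', 'x1', 10)]
print(top_n_per_category_unsorted(shuffled, 2))
```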

Mihai Toader answered Nov 17 '22 14:11


Some additional details ... I timed both mhyfritz's excellent solution, which uses itertools, and my code (brute force).

Here are the timeit results for n = 10 and for a list with 1 million elements.

# Here's how I built the sample list of 1 million entries.
lot = []
for i in range(1001):
    for j in reversed(range(333)):
        for k in range(3):
            lot.append((i, 'x', j))

# timeit Results for n = 10
brute_force = 6.55s
itertools = 2.07s
# clearly the itertools solution provided by mhyfritz is much faster.
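
A comparison along these lines could be set up roughly as follows. This is a sketch, not the exact harness behind the numbers above: the function names brute_force and with_itertools, the smaller sample list, and number=10 are all assumptions chosen so the example runs quickly, and absolute timings will differ by machine and Python version.

```python
import timeit
from itertools import groupby, islice
from operator import itemgetter

def brute_force(lot, n):
    # Single pass that counts distinct values seen so far in the current category.
    res, prev_t, count = [lot[0]], lot[0], 1
    for t in lot[1:]:
        if t[0] == prev_t[0]:
            if t[2] != prev_t[2]:
                count += 1
            if count <= n:
                res.append(t)
        else:
            count = 1
            res.append(t)
        prev_t = t
    return res

def with_itertools(lot, n):
    # Nested groupby: categories outside, runs of equal values inside.
    ans = []
    for _, g1 in groupby(lot, itemgetter(0)):
        for _, g2 in islice(groupby(g1, itemgetter(2)), n):
            ans.extend(g2)
    return ans

# Same shape as the sample list, but much smaller than 1 million entries.
lot = [(i, 'x', j) for i in range(101) for j in reversed(range(33)) for k in range(3)]

# Check that both approaches agree before timing them.
assert brute_force(lot, 10) == with_itertools(lot, 10)

for func in (brute_force, with_itertools):
    seconds = timeit.timeit(lambda: func(lot, 10), number=10)
    print(func.__name__, round(seconds, 3))
```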

In case anyone is curious, here is a trace of how his code works.

+ Outer loop - x, g1
| a [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8), ('a', 'x6', 7)]
+-- Inner loop - y, g2
  |- 10 [('a', 'x1', 10)]
  |- 9 [('a', 'x2', 9), ('a', 'x3', 9)]
  |- 8 [('a', 'x4', 8), ('a', 'x5', 8)]
+ Outer loop - x, g1
| b [('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), ('b', 'x4', 7), ('b', 'x5', 6), ('b', 'x6', 5)]
+-- Inner loop - y, g2
  |- 10 [('b', 'x1', 10)]
  |- 9 [('b', 'x2', 9)]
  |- 8 [('b', 'x3', 8)]
Praveen Gollakota answered Nov 17 '22 13:11