 

Find the largest k numbers in k arrays stored across k machines

This is an interview question. I have K machines, each of which is connected to one central machine. Each of the K machines has an array of 4-byte numbers in a file. You can use any data structure to load those numbers into memory on those machines, and they fit. Numbers are not unique across the K machines. Find the K largest numbers in the union of the numbers across all K machines. What is the fastest I can do this?

asked Mar 26 '12 by Aks


2 Answers

(This is an interesting problem because it involves parallelism. As I haven't encountered parallel algorithm optimization before, it's quite amusing: you can get away with some ridiculously high-complexity steps, because you can make up for it later. Anyway, onto the answer...)

> "What is the fastest I can do this?"

The best you can do is O(K). Below I illustrate both a simple O(K log(K)) algorithm, and the more complex O(K) algorithm.


First step:

Each computer needs enough time to read every element. This means that unless the elements are already in memory, one of the two bounds on the time is O(largest array size). If for example your largest array size varies as O(K log(K)) or O(K^2) or something, no amount of algorithmic trickery will let you go faster than that. Thus the actual best running time is O(max(K, largestArraySize)) technically.

Let us say the arrays have a max length of N, so the time is O(max(N, K)). With the above caveat, we're allowed to assume N <= K without loss of generality: each computer has to look at each of its elements at least once anyway (O(N) preprocessing per computer), and in that same O(N) pass it can keep only its K largest elements (this is known as finding the Kth order statistic; see the linear-time selection algorithms). In other words, this reduction comes for free (since it's also O(N)).
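
As a concrete illustration of that free preprocessing step, here is a minimal single-machine sketch in Python (mine, not from the answer); heapq.nlargest stands in for a true linear-time selection algorithm, costing O(N log K) instead, which is close enough for illustration:

import heapq

def top_k_on_machine(values, k):
    # Reduce one machine's array to its K largest elements.
    # A linear-time Kth-order-statistic selection would do this in O(N);
    # heapq.nlargest is an O(N log K) stand-in that is simpler to show.
    if len(values) <= k:
        return list(values)
    return heapq.nlargest(k, values)

# hypothetical example: 3 machines, K = 3
machines = [[5, 1, 9, 7], [2, 9, 4], [8, 8, 3, 6, 0]]
print([top_k_on_machine(m, 3) for m in machines])
# -> [[9, 7, 5], [9, 4, 2], [8, 8, 6]]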


Bounds and reasonable expectations:

Let's begin by thinking of some worst-case scenarios, and estimates for the minimum amount of work necessary.

  • One minimum-work-necessary estimate is O(K*N/K) = O(N): there are up to K*N elements in total and we need to look at every one of them at least once, but if we're smart we can distribute that work evenly across all K computers (hence the division by K).
  • Another minimum-work-necessary estimate is O(N): if a single computer's array happens to contain all of the K largest elements, that computer alone must scan its N elements before it can return them.
  • We must output all K elements; this is at least O(K) to print them out. We can avoid this if we are content merely knowing where the elements are, in which case the O(K) bound does not necessarily apply.

Can this bound of O(N) be achieved? Let's see...


Simple approach - O(NlogN + K) = O(KlogK):

For now let's come up with a simple approach, which achieves O(NlogN + K).

Consider the data arranged like so, where each column is a computer, and each row is a number in the array:

computer: A  B  C  D  E  F  G
      10 (o)      (o)          
       9  o (o)         (o)    
       8  o    (o)             
       7  x     x    (x)        
       6  x     x          (x)  
       5     x     ..........     
       4  x  x     ..          
       3  x  x  x  . .          
       2     x  x  .  .        
       1     x  x  .           
       0     x  x  .           

You can also imagine this as a sweep-line algorithm from computational geometry, or an efficient variant of the 'merge' step from mergesort. The elements in parentheses are the ones with which we initialize our candidate solution (on some central server). The algorithm will converge on the correct o answers by dumping the (x) candidates in favor of the two as-yet-unselected os.

Algorithm:

  • All computers start as 'active'.
  • Each computer sorts its elements. (parallel O(N logN))
  • Repeat until all computers are inactive:
    • Each active computer finds the next-highest element (O(1) since sorted) and gives it to the central server.
    • The server combines the new elements with its current K candidates and removes an equal number of the lowest elements from the combined set. To do this efficiently, we keep a global priority queue of fixed size K: we insert the new, potentially better elements, and the worst elements fall out of the set. Whenever an element falls out, we tell the computer that sent it to never send another one. (Justification: an evicted element is already smaller than K candidates, and everything that computer would send afterwards is smaller still, so none of it can ever enter the candidate set, whose smallest element only rises.)

(Sidenote: attaching a callback hook that fires when an element falls out of the priority queue is an O(1) operation.)

We can see graphically that this will perform at most 2K*(findNextHighest_time + queueInsert_time) operations, and as we do so, elements will naturally fall out of the priority queue. findNextHighest_time is O(1) since we sorted the arrays, so to minimize 2K*queueInsert_time, we choose a priority queue with an O(1) insertion time (e.g. a Fibonacci-heap based priority queue). This gives us an O(log(queue_size)) extraction time (we cannot have O(1) insertion and extraction); however, we never need to use the extract operation! Once we are done, we merely dump the priority queue as an unordered set, which takes O(queue_size)=O(K) time.

We'd thus have O(N log(N) + K) total running time (parallel sorting, followed by O(K)*O(1) priority queue insertions). In the worst case of N=K, this is O(K log(K)).
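
Here is a rough single-process simulation of that simple approach, assuming the data already sits in Python lists; a plain heapq min-heap plays the role of the central server's fixed-size priority queue (the names and the sequential round-robin loop are mine, not part of the answer):

import heapq

def k_largest_simple(machines, k):
    # Each machine sorts its own array (parallel O(N log N) in the answer).
    sorted_machines = [sorted(m, reverse=True) for m in machines]
    next_index = [0] * len(machines)              # next element each machine would send
    active = [len(m) > 0 for m in sorted_machines]

    heap = []                                     # central min-heap of (value, machine_id)
    while any(active):
        for mid, machine in enumerate(sorted_machines):
            if not active[mid]:
                continue
            value = machine[next_index[mid]]      # that machine's next-highest element
            next_index[mid] += 1
            if next_index[mid] == len(machine):
                active[mid] = False               # nothing left to send
            heapq.heappush(heap, (value, mid))
            if len(heap) > k:
                _, loser = heapq.heappop(heap)    # worst candidate falls out...
                active[loser] = False             # ...and its machine stops sending

    return sorted((v for v, _ in heap), reverse=True)

# hypothetical example with 3 machines and K = 3
print(k_largest_simple([[5, 1, 9], [7, 2], [8, 8, 3]], 3))   # -> [9, 8, 8]

The eviction step is what lets whole machines go quiet early, which is the point of the approach.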


The better approach - O(N+K) = O(K):

However, I have come up with a better approach, which achieves O(K). It is based on the median-of-medians selection algorithm, but parallelized. It goes like this:

We can eliminate a set of numbers if we know for sure that there are at least K (not strictly) larger numbers somewhere among all the computers.

Algorithm:

  • Each computer finds the sqrt(N)th highest element of its set, and splits the set into elements < and > it. This takes O(N) time in parallel.
  • The computers collaborate to combine those statistics into a new set, and find the K/sqrt(N)th highest element of that set (let's call it the 'superstatistic'), and note which computers have statistics < and > the superstatistic. This takes O(K) time.
  • Now consider all elements less than their computer's statistic, on computers whose statistic is less than the superstatistic. Those elements can be eliminated. This is because the elements greater than their computer's statistic, on computers whose statistic is larger than the superstatistic, form a set of at least K elements which are larger. (See the rectangle diagram below.)
  • Now, the computers with the uneliminated elements evenly redistribute their data to the computers who lost data.
  • Recurse: you still have K computers, but the value of N has decreased. Once N is less than a predetermined constant, use the previous algorithm I mentioned in "simple approach - O(NlogN + K)"; except in this case, it is now O(K). =)

It turns out that the reductions are O(N) total (amazingly not order K), except perhaps the final step, which might be O(K). Thus this algorithm is O(N+K) = O(K) total.
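
To make the superstatistic step concrete, here is a hypothetical single-process simulation of one elimination round (the function name and structure are mine; heapq.nlargest again stands in for linear-time selection, non-empty arrays are assumed, and the redistribution step is omitted -- this only shows which elements survive a round):

import heapq
import math

def one_elimination_round(machines, k):
    n = max(len(m) for m in machines)
    r = max(1, math.isqrt(n))                              # roughly sqrt(N)
    # Each machine's statistic: its sqrt(N)th highest element.
    stats = [heapq.nlargest(min(r, len(m)), m)[-1] for m in machines]
    # Superstatistic: the (K/sqrt(N))th highest of the statistics.
    q = max(1, k // r)
    superstat = heapq.nlargest(q, stats)[-1]
    # Machines whose statistic is below the superstatistic drop
    # everything below their own statistic; the rest keep all elements.
    survivors = []
    for m, s in zip(machines, stats):
        if s < superstat:
            survivors.append([x for x in m if x >= s])
        else:
            survivors.append(list(m))
    return survivors

# hypothetical example: K = 4 machines, N = 4 elements each
machines = [[1, 5, 9, 2], [7, 3, 8, 4], [6, 0, 2, 1], [9, 9, 1, 3]]
print(one_elimination_round(machines, 4))
# -> [[5, 9], [7, 3, 8, 4], [6, 2], [9, 9, 1, 3]]  (the 4 largest all survive)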

Analysis and simulation of O(K) running time below. The statistics allow us to divide the world into four unordered sets, represented here as a rectangle divided into four subboxes:

         ------N-----

         N^.5            
         ________________                     
|       |     s          |  <- computer
|       | #=K s  REDIST. |  <- computer
|       |     s          |  <- computer
| K/N^.5|-----S----------|  <- computer
|       |     s          |  <- computer
K       |     s          |  <- computer
|       |     s  ELIMIN. |  <- computer
|       |     s          |  <- computer
|       |     s          |  <- computer
|       |_____s__________|  <- computer

LEGEND:
s=statistic, S=superstatistic
#=K -- set of K largest elements    

(I'd draw the relation between the unordered sets of rows and the s-column here, but it would clutter things up; see the addendum.)

For this analysis, we will consider N as it decreases.

At a given step, we are able to eliminate the elements labelled ELIMIN; this has removed area from the rectangle representation above, reducing the problem size from K*N to K*N - (K - K/√N)(N - √N), which hilariously simplifies to K(2√N - 1).

Now, the computers with the uneliminated elements redistribute their data (REDIST rectangle above) to the computers with eliminated elements (ELIMIN). This is done in parallel, where the bandwidth bottleneck corresponds to the length of the short side of REDIST (because the REDIST computers are outnumbered by the ELIMIN computers waiting for their data). Therefore the transfer will take as long as the long side of the REDIST rectangle (another way of thinking about it: K/√N * (N-√N) is the area, divided by K/√N data-per-time, resulting in O(N-√N) time).

Thus at each step of size N, we are able to reduce the problem size to K(2√N-1), at the cost of performing N + 3K + (N-√N) work. We now recurse. The recurrence relation which will tell us our performance is:

T(N) = 2N+3K-√N + T(2√N-1)

The decimation of the subproblem size is much faster than the normal geometric series (being √N rather than something like N/2 which you'd normally get from common divide-and-conquers). Unfortunately neither the Master Theorem nor the powerful Akra-Bazzi theorem work, but we can at least convince ourselves it is linear via a simulation:

>>> from math import sqrt
>>> def T(n, k=None):
...     return 1 if n < 10 else sqrt(n)*(2*sqrt(n)-1) + 3*k + T(2*sqrt(n)-1, k=k)
>>> f = (lambda x: x)
>>> (lambda n: T((10**5)*n, k=(10**5)*n)/f((10**5)*n) - T(n, k=n)/f(n))(10**30)
-3.552713678800501e-15

The function T(N) is, at large scales, a multiple of the linear function x, hence linear (doubling the input doubles the output). This method, therefore, almost certainly achieves the bound of O(N) we conjecture. Though see the addendum for an interesting possibility.
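
Relatedly, a quick way to see how violently the subproblem size collapses (a point the addendum returns to) is to count the recursion depth directly; this little check is mine, not part of the original answer:

>>> from math import sqrt
>>> def depth(n):
...     steps = 0
...     while n >= 10:              # same base case as the simulation above
...         n = 2*sqrt(n) - 1
...         steps += 1
...     return steps
>>> [depth(10**e) for e in (6, 12, 24, 48, 96)]
[4, 5, 6, 7, 8]

The number of rounds grows only roughly like log(log(N)), which is the quantity that the 3K-per-step work gets multiplied by in the addendum's discussion.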


...


Addendum

  • One pitfall is accidentally sorting. If we do anything which accidentally sorts our elements, we will incur a log(N) penalty at the least. Thus it is better to think of the arrays as sets, to avoid the pitfall of thinking that they are sorted.
  • Also, we might initially think that, because of the constant 3K amount of work at each step, we would have to do about 3K*log(log(N)) work in total. But the -1 has a powerful role to play in the decimation of the problem size. It is very slightly possible that the running time is actually something above linear, but definitely much smaller than even N*log(log(log(log(N)))). For example, it might be something like O(N*InverseAckermann(N)), but I hit the recursion limit when testing.
  • The O(K) is probably only due to the fact that we have to print them out; if we are content merely knowing where the data is, we might even be able to pull off an O(N) (e.g. if the arrays are of length O(log(K)) we might be able to achieve O(log(K)))... but that's another story.
  • The relation between the unordered sets is as follows (it would have cluttered things up in the explanation above):


          _
         / \
(.....) > s > (.....)
          s
(.....) > s > (.....)
          s
(.....) > s > (.....)
         \_/

          v

          S

          v

         / \
(.....) > s > (.....)
          s
(.....) > s > (.....)
          s
(.....) > s > (.....)
         \_/
answered Oct 03 '22 by ninjagecko


  • Find the k largest numbers on each machine. O(n*log(k))
  • Combine the results (on a centralized server, if k is not huge; otherwise you can merge them in a tree hierarchy across the server cluster).

Update: to make it clear, the combine step is not a sort. You just pick the top k numbers from the results. There are many ways to do this efficiently. For example, you can use a heap: push the head of each list, then repeatedly pop the maximum and push the next element from the list the popped element belonged to. Doing this k times gives you the result. All of this is O(k*log(k)).
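
A minimal sketch of that combine step in Python, assuming each machine's top-k list arrives already sorted in descending order (function and variable names are illustrative, not the answerer's):

import heapq

def combine_top_k(per_machine_top_k, k):
    # heapq is a min-heap, so store negated values to pop the largest first.
    heap = [(-lst[0], mid, 0) for mid, lst in enumerate(per_machine_top_k) if lst]
    heapq.heapify(heap)

    result = []
    while heap and len(result) < k:
        neg_val, mid, idx = heapq.heappop(heap)           # current largest head
        result.append(-neg_val)
        nxt = idx + 1
        if nxt < len(per_machine_top_k[mid]):             # push that list's next element
            heapq.heappush(heap, (-per_machine_top_k[mid][nxt], mid, nxt))
    return result

# hypothetical example, k = 3
print(combine_top_k([[9, 7, 5], [9, 4, 2], [8, 8, 6]], 3))   # -> [9, 9, 8]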

answered Oct 03 '22 by Karoly Horvath