 

What is the most efficient way of getting the intersection of k sorted arrays?

Given k sorted arrays, what is the most efficient way of getting the intersection of these lists?

Example

INPUT:

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

Output:

[1,7]

Based on what I read in the book Elements of Programming Interviews, there is a way to get the union of k sorted arrays (merging them into one sorted list) in O(n log k) time, where n is the total number of elements. I was wondering if there is a way to do something similar for the intersection as well.

## merge sorted arrays in O(n log k) time [concatenating and re-sorting would be O(n log n)]
import heapq
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # explicit None check so a 0 element is not skipped
            heapq.heappush(heap, (elem, idx))
    
    res = []
 
    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, ary))

    return res

EDIT: This is an algorithm question I am trying to solve, so I cannot use built-in functions such as set intersection.

asked May 13 '21 at 20:05 by identical123456




1 Answer

Yes, it is possible! I've modified your example code to do this.

My answer assumes that your question is about the algorithm; if you want the fastest-running code using sets, see other answers.

This maintains the O(n log(k)) time complexity, where n is the total number of elements across all k sub-lists. The heap never holds more than k entries, so all the code between if lowest != elem or ary != times_seen: and unbench_all = False is O(log(k)). There is a nested loop inside the main loop (for unbenched in range(times_seen):), but it only runs times_seen times: times_seen starts at 0, is reset to 0 every time this inner loop runs, and can be incremented at most once per main-loop iteration, so the inner loop cannot do more iterations in total than the main loop. The main loop itself runs O(n) times, because every iteration either pops an entry from the heap (at most n pops, one per element) or resets a non-zero times_seen, which needs at least one earlier pop. Thus, since the inner loop body is O(log(k)) and runs at most as many times as the outer loop, and the outer loop body is O(log(k)) and runs O(n) times, the algorithm is O(n log(k)).
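
The counting argument can be restated compactly. This is a summary under the reading that n is the total number of elements over all k sub-lists and that each element is pushed onto and popped from the heap at most once, with the heap holding at most one entry per sub-list:

$$
\#\text{pops} \le n, \qquad \#\text{pushes} \le n, \qquad \#\text{loop iterations} \le 2n
$$

$$
T(n, k) = O\big(\#\text{loop iterations} + (\#\text{pushes} + \#\text{pops}) \log k\big) = O(n \log k)
$$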

This algorithm relies on how tuples are compared in Python: the first items are compared, and if they are equal, the second items are compared (i.e. (x, a) < (x, b) is true if and only if a < b). Unlike in the example code in the question, when an item is popped from the heap in this algorithm, it is not necessarily pushed again in the same iteration. Since we need to check whether all sub-lists contain the same number, after a number is popped from the heap its sub-list is what I call "benched", meaning that its next item is not added back to the heap yet. We first need to check whether the other sub-lists contain the same number, so this sub-list's next item is not needed right now.
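
A quick way to see this tie-breaking rule in action, using only the standard heapq module. The heap entries below are made up for illustration; each is a (number, sub-list index) pair as in the answer's code:

```python
import heapq

# Tuples compare item by item: a tie on the first item falls through to the second.
print((2, 0) < (2, 1))  # True: tie on 2, then 0 < 1
print((2, 3) < (3, 0))  # True: decided by the first item alone

# heapq uses the same ordering, so among entries holding the same number
# the one with the lowest sub-list index is popped first.
heap = [(2, 2), (2, 0), (5, 1), (2, 1)]
heapq.heapify(heap)
popped = [heapq.heappop(heap) for _ in range(len(heap))]
print(popped)  # [(2, 0), (2, 1), (2, 2), (5, 1)]
```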

If a number is indeed in all sub-lists, the heap will look something like [(2,0),(2,1),(2,2),(2,3)], with all the first elements of the tuples equal, so heappop will select the entry with the lowest sub-list index. First index 0 is popped and times_seen is incremented to 1, then index 1 is popped and times_seen is incremented to 2, and so on. If ary is ever not equal to times_seen, the number is not in the intersection of all sub-lists. This leads to the condition if lowest != elem or ary != times_seen:, which decides when a number shouldn't be in the result; the else branch of this if statement handles the case where it still might be in the result.
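
The index check can be isolated as follows. first_missing_index is a hypothetical helper, not part of the answer's code: it pops entries holding the current minimum in the order heapq yields them, counts them the way times_seen does, and stops at the first sub-list index that is skipped. It is a sketch of the check only; the answer's loop additionally advances and unbenches sub-lists.

```python
import heapq


def first_missing_index(heap, k):
    """Hypothetical helper: pop entries equal to the current minimum while their
    sub-list index matches the running count (like times_seen). Returns None if
    all k sub-lists hold the minimum, otherwise the index of the first sub-list
    that does not. Mutates `heap`."""
    lowest = heap[0][0]
    times_seen = 0
    while heap and heap[0][0] == lowest and heap[0][1] == times_seen:
        heapq.heappop(heap)
        times_seen += 1
    return None if times_seen == k else times_seen


full = [(2, 0), (2, 1), (2, 2), (2, 3)]
heapq.heapify(full)
print(first_missing_index(full, 4))     # None: every sub-list starts with 2

partial = [(2, 0), (2, 1), (2, 3), (5, 2)]
heapq.heapify(partial)
print(first_missing_index(partial, 4))  # 2: sub-list 2 starts with 5, not 2
```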

The unbench_all boolean is for when all sub-lists need to be removed from the bench - this could be because:

  1. The current number is known to not be in the intersection of the sub-lists
  2. It is known to be in the intersection of the sub-lists

When unbench_all is True, all the sub-lists that were removed from the heap are re-added. These are exactly the ones with indices in range(times_seen): the algorithm only removes (benches) entries that hold the same number, and among equal numbers the heap yields the lowest sub-list index first, so the benched sub-lists must have been removed in order of index, contiguously and starting from index 0, and there must be times_seen of them. This means that we don't need to store the indices of the benched sub-lists, only how many have been benched.
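
The unbenching step on its own might look like this. unbench is a hypothetical helper that mirrors the answer's if unbench_all: block; the arrays and heap state are an invented scenario in which 1 turns out not to be in the intersection, so the two benched sub-lists go back on the heap:

```python
import heapq


def unbench(heap, iters, times_seen):
    """Hypothetical helper mirroring the answer's `if unbench_all:` block.
    The benched sub-lists are exactly indices 0 .. times_seen-1, so only the
    count is needed to push each one's next element back onto the heap.
    Returns the reset value of times_seen."""
    for idx in range(times_seen):
        nxt = next(iters[idx], None)
        if nxt is not None:  # an exhausted sub-list is simply left off the heap
            heapq.heappush(heap, (nxt, idx))
    return 0


arrays = [[1, 3], [1, 4], [2, 5]]
iters = [iter(a) for a in arrays]
for it in iters:
    next(it)  # consume each sub-list's first element, as the initial heap fill does

# Sub-lists 0 and 1 were benched after popping their 1; sub-list 2 heads the
# heap with 2, so 1 is not in the intersection and both must be unbenched.
heap = [(2, 2)]
times_seen = unbench(heap, iters, 2)
print(sorted(heap), times_seen)  # [(2, 2), (3, 0), (4, 1)] 0
```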

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # explicit None check so a 0 element is not skipped
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of times that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res


if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))
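
Note that this code keeps duplicates: tracing it by hand, the script prints [1, 7] for a1 and [1, 1] for a2, because both sub-lists of a2 contain 1 twice. To check outputs like these, a brute-force reference for that multiset interpretation could look like the sketch below. It is an assumed testing aid, not part of the answer, and it uses collections.Counter, which the question rules out for the real solution.

```python
from collections import Counter


def multiset_intersection(arrays):
    """Testing oracle only: keep each value as many times as it appears in
    every array (the minimum count). It leans on Counter, so it does not meet
    the question's no-built-ins constraint."""
    if not arrays:
        return []
    common = Counter(arrays[0])
    for arr in arrays[1:]:
        common &= Counter(arr)  # & keeps the minimum of the two counts
    return sorted(common.elements())


print(multiset_intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
print(multiset_intersection([[1, 1], [1, 1, 2, 2, 3]]))  # [1, 1]
```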

An equivalent algorithm can be written like this, if you prefer:

import heapq


def mergeArys(srtd_arys):
    k = len(srtd_arys)
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time; once some sub-list is exhausted the heap
    # holds fewer than k entries and nothing further can be in every sub-list
    while heap and len(heap) == k:
        lowest = heap[0][0]
        benched = 0  # sub-lists 0 .. benched-1 have been popped for `lowest`
        for i in range(k):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                # sub-list i does not start with `lowest`, so it is not common
                if ary != i:
                    # the top entry is smaller than sub-list i's head: discard it
                    heapq.heappop(heap)
                    nxt = next(srtd_iters[ary], None)
                    if nxt is not None:
                        heapq.heappush(heap, (nxt, ary))
                break
            heapq.heappop(heap)
            benched += 1

        if benched == k:
            res.append(lowest)

        # put the benched sub-lists back on the heap
        for unbenched in range(benched):
            nxt = next(srtd_iters[unbenched], None)
            if nxt is not None:
                heapq.heappush(heap, (nxt, unbenched))

    return res
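
For comparison, here is a different technique, not the heap approach above: intersect two sorted lists with two pointers, then fold that over the k arrays starting from the shortest. Each pairwise step costs the length of the running result plus the length of the next array, and the running result never grows beyond the shortest array (length m, with km ≤ n), so the total is at most (k - 1)m + n ≤ 2n comparisons, plus ordering the k arrays by length. Like the answer's code, this sketch keeps duplicates; the function names are illustrative.

```python
def intersect_two(a, b):
    """Two-pointer intersection of two sorted lists, keeping duplicates
    (each match consumes one element from each list). O(len(a) + len(b))."""
    i = j = 0
    out = []
    while i < len(a) and j < len(b):
        if a[i] < b[j]:
            i += 1
        elif a[i] > b[j]:
            j += 1
        else:
            out.append(a[i])
            i += 1
            j += 1
    return out


def intersect_k(arrays):
    """Fold intersect_two over all arrays, shortest first, so the running
    result never exceeds the shortest array."""
    if not arrays:
        return []
    ordered = sorted(arrays, key=len)  # orders k list references, not elements
    result = list(ordered[0])
    for arr in ordered[1:]:
        if not result:
            break  # nothing left that could be common
        result = intersect_two(result, arr)
    return result


print(intersect_k([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
print(intersect_k([[1, 1], [1, 1, 2, 2, 3]]))  # [1, 1]
```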

answered Sep 29 '22 at 15:09 by Oli