 

Union find implementation using Python

So here's what I want to do: I have a list that contains several equivalence relations:

l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]

And I want to union the sets that share at least one element. Here is a sample implementation:

def union(lis):
    lis = [set(e) for e in lis]
    res = []
    while True:
        # One pass: merge each set into the first result set it overlaps.
        for a in lis:
            pointer = 0
            while pointer < len(res):
                if a & res[pointer]:
                    res[pointer] = res[pointer].union(a)
                    break
                pointer += 1
            if pointer == len(res):
                res.append(a)
        # Stop once a full pass no longer merges anything.
        if res == lis:
            break
        lis, res = res, []
    return res

And print(union(l)) prints

[set([1, 2, 3, 6, 7]), set([4, 5])]

This does the right thing but is far too slow when the list of equivalence relations is large. I looked up the description of the union-find algorithm: http://en.wikipedia.org/wiki/Disjoint-set_data_structure but I am still having trouble coding a Python implementation.

asked Nov 22 '13 20:11 by thomas zhang


People also ask

How do you implement union-find in Python?

To implement union-find in Python, represent each set as a tree. The tree's root acts as the set's representative, and each node holds a reference to its parent node. Find follows the parent pointers up to the root, and union combines two trees by attaching one root under the other.
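As a rough sketch of that description (not code from either answer below), here is a dictionary-based disjoint-set forest applied to the question's list. The names DisjointSet and group_pairs are hypothetical, elements are assumed to be hashable, and it includes two standard optimizations that the paragraph above does not mention: path halving in find and union by size.

```python
from collections import defaultdict


class DisjointSet:
    """Disjoint-set forest: each element points to a parent; roots represent sets."""

    def __init__(self):
        self.parent = {}  # element -> parent element (a root points to itself)
        self.size = {}    # root -> number of elements in its tree

    def add(self, x):
        # Create a singleton set for x if it has not been seen yet.
        if x not in self.parent:
            self.parent[x] = x
            self.size[x] = 1

    def find(self, x):
        # Walk parent pointers up to the root, halving the path as we go.
        self.add(x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        # Attach the root of the smaller tree under the root of the larger one.
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]


def group_pairs(pairs):
    """Merge pairs that share an element; return a list of disjoint sets."""
    ds = DisjointSet()
    for a, b in pairs:
        ds.union(a, b)
    groups = defaultdict(set)
    for x in ds.parent:  # find only rewrites values, so the key set is stable
        groups[ds.find(x)].add(x)
    return list(groups.values())


l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
print(group_pairs(l))  # prints [{1, 2, 3, 6, 7}, {4, 5}]
```

With both optimizations each union or find takes near-constant amortized time, so grouping n pairs costs O(n alpha(n)) overall.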

Which algorithm uses union-find?

Kruskal's algorithm uses a union-find data structure as a subroutine to detect cycles: an edge is added to the minimum spanning tree only if its two endpoints are currently in different sets, after which those two sets are merged.
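A minimal Kruskal sketch under assumed conventions: vertices are the integers 0..num_vertices-1 and edges are (weight, u, v) tuples; the function name kruskal is hypothetical. The union-find part uses only path halving, to keep it short.

```python
def kruskal(num_vertices, edges):
    """Return (total_weight, mst_edges); edges are (weight, u, v) tuples."""
    parent = list(range(num_vertices))  # each vertex starts as its own root

    def find(x):
        # Follow parent pointers to the root, halving the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total, mst = 0, []
    for w, u, v in sorted(edges):  # consider edges by increasing weight
        ru, rv = find(u), find(v)
        if ru == rv:
            continue  # u and v already connected: this edge would form a cycle
        parent[rv] = ru  # union the two components
        total += w
        mst.append((u, v, w))
    return total, mst


edges = [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 1, 3), (5, 2, 3)]
print(kruskal(4, edges))  # prints (6, [(0, 1, 1), (1, 3, 2), (1, 2, 3)])
```

For a disconnected graph the same loop returns a minimum spanning forest.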

Is union-find faster than DFS?

If the graph is already in memory in adjacency-list format, then DFS is slightly simpler and faster (O(n) versus O(n alpha(n)), where alpha(n) is the inverse Ackermann function), but union-find can handle edges arriving online in any order, which is sometimes useful (e.g., when there are too many edges to fit in main memory).
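To illustrate the trade-off, here is a sketch with two hypothetical helpers that compute the same connected components: components_dfs needs the complete adjacency list up front, while components_online consumes edges from any iterable one at a time, so they could come from a file or a stream. For brevity the union-find side uses only path halving, so by itself it does not achieve the O(n alpha(n)) bound mentioned above; that also needs union by rank or size.

```python
from collections import defaultdict


def components_dfs(edges):
    """Connected components via iterative DFS; needs the whole adjacency list."""
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen, comps = set(), []
    for start in adj:
        if start in seen:
            continue
        seen.add(start)
        stack, comp = [start], set()
        while stack:
            node = stack.pop()
            comp.add(node)
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        comps.append(comp)
    return comps


def components_online(edge_iter):
    """Connected components via union-find; consumes edges one at a time."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)  # unseen vertices become singleton roots
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edge_iter:  # edges may arrive in any order
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[rv] = ru
    groups = defaultdict(set)
    for x in list(parent):
        groups[find(x)].add(x)
    return list(groups.values())


edges = [(1, 2), (2, 3), (4, 5), (6, 7), (1, 7)]
print(components_dfs(edges))           # prints [{1, 2, 3, 6, 7}, {4, 5}]
print(components_online(iter(edges)))  # prints [{1, 2, 3, 6, 7}, {4, 5}]
```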


2 Answers

Solution that runs in O(n) time

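from collections import defaultdict

def indices_dict(lis):
    # Map each element to the indices of the pairs in lis that contain it.
    d = defaultdict(list)
    for i, (a, b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)
    return d

def disjoint_indices(lis):
    # Group pair indices into connected classes with a breadth-first sweep.
    d = indices_dict(lis)
    sets = []
    while d:
        # Start a new class from an arbitrary unvisited element.
        que = set(d.popitem()[1])
        ind = set()
        while que:
            ind |= que
            # Follow every element of the queued pairs to pairs not yet seen;
            # d.pop ensures each element is expanded at most once.
            que = set(y for i in que
                        for x in lis[i]
                        for y in d.pop(x, [])) - ind
        sets.append(ind)
    return sets

def disjoint_sets(lis):
    # Turn each class of pair indices into the set of elements it covers.
    return [set(x for i in s for x in lis[i]) for s in disjoint_indices(lis)]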

How it works:

>>> lis = [(1,2),(2,3),(4,5),(6,7),(1,7)]
>>> indices_dict(lis)
defaultdict(<type 'list'>, {1: [0, 4], 2: [0, 1], 3: [1], 4: [2], 5: [2], 6: [3], 7: [3, 4]})

indices_dict maps each element to the indices of the pairs in lis that contain it. E.g. 1 is mapped to indices 0 and 4 in lis.

>>> disjoint_indices(lis)
[set([0, 1, 3, 4]), set([2])]

disjoint_indices gives a list of disjoint sets of indices. Each set holds the indices of the pairs that belong to one equivalence class. E.g. lis[0] and lis[3] are in the same class, but lis[2] is not.

>>> disjoint_sets(lis)
[set([1, 2, 3, 6, 7]), set([4, 5])]

disjoint_sets converts each set of indices into the corresponding set of elements, i.e. the equivalence classes.


Time complexity

The O(n) time complexity is difficult to see, but I'll try to explain. Here n = len(lis), and dict and set operations are assumed to take O(1) average time.

  1. indices_dict clearly runs in O(n) time because it is a single for-loop over lis.

  2. disjoint_indices is the hardest to see. Every key of d is removed exactly once (by popitem or d.pop), and every index enters que at most once because ind is subtracted, so the total work is O(len(d) + n). Now, len(d) <= 2n, since d maps elements to indices and there are at most 2n distinct elements in lis. Therefore, the function runs in O(n).

  3. disjoint_sets is difficult to see because of the 3 nested for-loops. However, the index sets are disjoint, so i runs over each of the n indices in lis exactly once in total, and x runs over the 2 elements of each pair, so the total complexity is 2n = O(n).

answered Sep 19 '22 16:09 by bcorso


I think this is an elegant solution, using the built-in set methods:

#!/usr/bin/python3

def union_find(lis):
    lis = map(set, lis)
    unions = []
    for item in lis:
        temp = []
        for s in unions:
            if not s.isdisjoint(item):
                item = s.union(item)
            else:
                temp.append(s)
        temp.append(item)
        unions = temp
    return unions



if __name__ == '__main__':
    l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
    print(union_find(l))

It returns a list of sets; for the example list it prints [{4, 5}, {1, 2, 3, 6, 7}].

answered Sep 16 '22 16:09 by JelteF