So here's what I want to do: I have a list that contains several equivalence relations:
l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
And I want to union the sets that share one element. Here is a sample implementation:
def union(lis):
    lis = [set(e) for e in lis]
    res = []
    while True:
        for i in range(len(lis)):
            a = lis[i]
            if res == []:
                res.append(a)
            else:
                pointer = 0
                # merge a into the first set in res that it overlaps
                while pointer < len(res):
                    if a & res[pointer] != set([]):
                        res[pointer] = res[pointer].union(a)
                        break
                    pointer += 1
                if pointer == len(res):
                    res.append(a)
        # repeat the pass until nothing merges any more
        if res == lis:
            break
        lis, res = res, []
    return res
And it prints
[set([1, 2, 3, 6, 7]), set([4, 5])]
This does the right thing but is way too slow when the list of equivalence relations is large. I looked up the description of the union-find algorithm: http://en.wikipedia.org/wiki/Disjoint-set_data_structure but I'm still having trouble coding a Python implementation.
To implement the Union-Find in Python, we use the concept of trees. The tree's root can act as a representative, and each node will hold the reference to its parent node. The Union-Find algorithm will traverse the parent nodes to reach the root and combine two trees by attaching their roots.
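The description above can be sketched roughly as follows. This is only a minimal sketch, not a reference implementation: the names find and union_find_groups are made up for illustration, and it assumes every item in the input is a pair of hashable elements. It uses a parent dict instead of node objects, with path compression and union by size.

```python
def find(parent, x):
    """Follow parent links from x up to the root (the representative)."""
    root = x
    while parent[root] != root:
        root = parent[root]
    # Path compression: point every node on the path straight at the root.
    while parent[x] != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root


def union_find_groups(pairs):
    """Hypothetical helper: merge pairs into disjoint sets of elements."""
    parent = {}  # element -> parent element (a root is its own parent)
    size = {}    # root -> number of elements in its tree
    for a, b in pairs:
        for x in (a, b):
            if x not in parent:
                parent[x] = x
                size[x] = 1
        ra, rb = find(parent, a), find(parent, b)
        if ra != rb:
            # Union by size: attach the smaller tree under the larger root.
            if size[ra] < size[rb]:
                ra, rb = rb, ra
            parent[rb] = ra
            size[ra] += size[rb]
    # Collect the elements under their representative.
    groups = {}
    for x in parent:
        groups.setdefault(find(parent, x), set()).add(x)
    return list(groups.values())


l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
print(sorted(sorted(g) for g in union_find_groups(l)))
# [[1, 2, 3, 6, 7], [4, 5]]
```

The order of the returned sets is not guaranteed, which is why the usage line sorts them before printing.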
In Kruskal's algorithm, the union-find data structure is used as a subroutine to detect cycles in the graph, which helps in finding the minimum spanning tree.
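To make the cycle check concrete, here is a rough sketch of Kruskal's algorithm. The function name kruskal, the (weight, u, v) edge format and the assumption that nodes are numbered 0..num_nodes-1 are all choices made for this example. An edge whose two endpoints already have the same root would close a cycle, so it is skipped.

```python
def kruskal(num_nodes, edges):
    """Return minimum spanning tree edges; edges are (weight, u, v) tuples."""
    parent = list(range(num_nodes))  # every node starts as its own root

    def find(x):
        # Walk up to the root, halving the path along the way.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    tree = []
    for w, u, v in sorted(edges):  # cheapest edges first
        ru, rv = find(u), find(v)
        if ru != rv:               # different roots: no cycle, keep the edge
            parent[ru] = rv
            tree.append((w, u, v))
    return tree


edges = [(1, 0, 1), (2, 1, 2), (3, 0, 2), (4, 2, 3)]
print(kruskal(4, edges))
# [(1, 0, 1), (2, 1, 2), (4, 2, 3)]
```

The edge (3, 0, 2) is dropped because nodes 0 and 2 are already connected through node 1 by the time it is considered.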
If the graph is already in memory in adjacency list format, then DFS is slightly simpler and faster (O(n) versus O(n alpha(n)), where alpha(n) is inverse Ackermann), but union-find can handle the edges arriving online in any order, which is sometimes useful (e.g., there are too many to fit in main memory).
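For comparison, the DFS alternative mentioned above might look something like this. It is a sketch only: components_dfs is a hypothetical name, and it assumes the pairs fit in memory so an adjacency list can be built first.

```python
from collections import defaultdict


def components_dfs(pairs):
    """Hypothetical helper: connected components via an iterative DFS."""
    # Build the adjacency list: element -> set of directly related elements.
    adj = defaultdict(set)
    for a, b in pairs:
        adj[a].add(b)
        adj[b].add(a)

    seen = set()
    components = []
    for start in list(adj):
        if start in seen:
            continue
        # Explore everything reachable from start with an explicit stack.
        seen.add(start)
        stack = [start]
        comp = set()
        while stack:
            node = stack.pop()
            comp.add(node)
            for nb in adj[node]:
                if nb not in seen:
                    seen.add(nb)
                    stack.append(nb)
        components.append(comp)
    return components


l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
print(sorted(sorted(c) for c in components_dfs(l)))
# [[1, 2, 3, 6, 7], [4, 5]]
```

Each element and each pair is visited a constant number of times, which is where the linear bound comes from.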
O(n) time

from collections import defaultdict

def indices_dict(lis):
    # map each element to the indices of the pairs that contain it
    d = defaultdict(list)
    for i, (a, b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)
    return d

def disjoint_indices(lis):
    d = indices_dict(lis)
    sets = []
    while len(d):
        # start a new group from any remaining element
        que = set(d.popitem()[1])
        ind = set()
        while len(que):
            ind |= que
            # follow every element of the queued pairs to pairs not yet seen
            que = set([y for i in que
                         for x in lis[i]
                         for y in d.pop(x, [])]) - ind
        sets += [ind]
    return sets

def disjoint_sets(lis):
    return [set([x for i in s for x in lis[i]]) for s in disjoint_indices(lis)]
>>> lis = [(1,2),(2,3),(4,5),(6,7),(1,7)]
>>> indices_dict(lis)
defaultdict(<type 'list'>, {1: [0, 4], 2: [0, 1], 3: [1], 4: [2], 5: [2], 6: [3], 7: [3, 4]})
indices_dict gives a map from each equivalence # to the indices of the pairs in lis that contain it. E.g. 1 is mapped to indices 0 and 4 in lis.
>>> disjoint_indices(lis)
[set([0, 1, 3, 4]), set([2])]
disjoint_indices gives a list of disjoint sets of indices. Each set holds the indices of the pairs in one equivalence class. E.g. lis[0] and lis[3] are in the same equivalence class, but lis[2] is not.
>>> disjoint_sets(lis)
[set([1, 2, 3, 6, 7]), set([4, 5])]
disjoint_sets converts the disjoint indices into their proper equivalence classes.
The O(n) time complexity is difficult to see, but I'll try to explain. Here I will use n = len(lis).
indices_dict certainly runs in O(n) time because it has only one for-loop over lis.
disjoint_indices is the hardest to see. It runs in O(len(d)) time, since the outer loop stops when d is empty and each key is removed from d at most once, either by popitem or by d.pop in the inner loop. Now, len(d) <= 2n, since d is a map from equivalence #'s to indices and there are at most 2n different equivalence #'s in lis. The index lists stored in d also hold exactly 2n entries in total, and each index enters que at most once, so walking them costs O(n) as well. Therefore, the function runs in O(n).
disjoint_sets is difficult to see because of the 3 combined for-loops. However, i runs over at most all n indices in lis, and x runs over the 2-tuple lis[i], so the total work is 2n = O(n).
I think this is an elegant solution, using the built-in set functions:
#!/usr/bin/python3

def union_find(lis):
    lis = map(set, lis)
    unions = []
    for item in lis:
        temp = []
        for s in unions:
            if not s.isdisjoint(item):
                # s overlaps item: absorb it into item
                item = s.union(item)
            else:
                temp.append(s)
        temp.append(item)
        unions = temp
    return unions

if __name__ == '__main__':
    l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
    print(union_find(l))
It returns a list of sets.