
Disjoint set implementation in Python


I am relatively new to Python. I am studying disjoint sets and implemented them as follows:

class DisjointSet:
    def __init__(self, vertices, parent):
        self.vertices = vertices
        self.parent = parent

    def find(self, item):
        if self.parent[item] == item:
            return item
        else:
            return self.find(self.parent[item])

    def union(self, set1, set2):
        self.parent[set1] = set2

And here is the driver code:

def main():
    vertices = ['a', 'b', 'c', 'd', 'e', 'h', 'i']
    parent = {}

    for v in vertices:
        parent[v] = v

    ds = DisjointSet(vertices, parent)
    print("Print all vertices in genesis: ")
    ds.union('b', 'd')

    ds.union('h', 'b')
    print(ds.find('h')) # prints d (OK)
    ds.union('h', 'i')
    print(ds.find('i')) # prints i (expecting d)

main()

So, I first initialized every node as its own disjoint set. Then I unioned b with d and h with b, which gives the set {h, b, d}. Next I unioned h with i, which I assumed would give the set {i, h, b, d}. I understand that the problem comes from this line in union(set1, set2):

self.parent[set1] = set2

It sets the parent of h to i, which detaches h from the set containing b and d. How can I get the set {i, h, b, d} so that the order of the arguments passed to union() does not change the result?
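To see exactly where the set splits, here is a small sketch that reuses the class from the question unchanged and prints the parent dictionary after the last union. For brevity it only uses the four vertices involved (b, d, h, i); that subset is my own simplification, not part of the original driver.

```python
class DisjointSet:
    def __init__(self, vertices, parent):
        self.vertices = vertices
        self.parent = parent

    def find(self, item):
        if self.parent[item] == item:
            return item
        return self.find(self.parent[item])

    def union(self, set1, set2):
        # Same as the question: re-parents set1 itself, not the root of its set.
        self.parent[set1] = set2


vertices = ['b', 'd', 'h', 'i']
ds = DisjointSet(vertices, {v: v for v in vertices})

ds.union('b', 'd')   # b -> d
ds.union('h', 'b')   # h -> b, so h -> b -> d
print(ds.find('h'))  # prints: d
ds.union('h', 'i')   # overwrites h -> b with h -> i
print(ds.parent)     # prints: {'b': 'd', 'd': 'd', 'h': 'i', 'i': 'i'}
print(ds.find('h'), ds.find('b'))  # prints: i d (h and b are now in different sets)
```

The link h -> b that joined h to {b, d} is simply overwritten, so h (and i) end up in a different tree from b and d.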


Solution

  • Your program does not work because union should change the parent of the root of a set, not the parent of the node passed in as input. As you have already noticed, overwriting the parent of an arbitrary input node detaches that node from its current set and undoes earlier unions.

    Here's a correct implementation:

    def union(self, set1, set2):
        # Re-parent the root of set1's tree, not set1 itself,
        # so that earlier unions are preserved.
        root1 = self.find(set1)
        root2 = self.find(set2)
        self.parent[root1] = root2

    I would also suggest reading the "Disjoint-set data structure" article for more information, including the common optimizations of path compression and union by rank.
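For reference, here is a minimal sketch of the corrected class with both of those optimizations added: path compression in `find` and union by rank in `union`. The `rank` dictionary, the iterative `find`, and building `parent` inside the constructor are my own choices for illustration, not part of the answer above.

```python
class DisjointSet:
    def __init__(self, vertices):
        # Every vertex starts as the root of its own singleton set.
        self.parent = {v: v for v in vertices}
        # rank is an upper bound on the height of the tree rooted at each vertex.
        self.rank = {v: 0 for v in vertices}

    def find(self, item):
        # First pass: walk up to the root.
        root = item
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass (path compression): point every node on the path at the root.
        while item != root:
            next_item = self.parent[item]
            self.parent[item] = root
            item = next_item
        return root

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return  # already in the same set
        # Union by rank: attach the shallower tree under the deeper one.
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1


ds = DisjointSet(['a', 'b', 'c', 'd', 'e', 'h', 'i'])
ds.union('b', 'd')
ds.union('h', 'b')
ds.union('h', 'i')
print(ds.find('i') == ds.find('h') == ds.find('d') == ds.find('b'))  # prints: True
print(ds.find('a') == ds.find('b'))  # prints: False
```

Note that the representative returned by `find` is not necessarily d (with these calls it happens to be b), because union by rank decides which root ends up on top. What matters is that i, h, b and d all share one root, whichever order the arguments to `union` are given in.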