
How to properly implement disjoint set data structure for finding spanning forests in Python?


Recently I have been implementing solutions to the Google Kick Start 2019 problems, and I tried to solve Round E's Cherries Mesh by following the explanation in the analysis. Here is the link to the problem and its analysis: https://codingcompetitions.withgoogle.com/kickstart/round/0000000000050edb/0000000000170721

Here is the code I implemented:

t = int(input())
for k in range(1,t+1):
    n, q = map(int,input().split())
    se = list()
    for _ in range(q):
        a,b = map(int,input().split())
        se.append((a,b))
    l = [{x} for x in range(1,n+1)]
    #print(se)
    for s in se:
        i = 0
        while ({s[0]}.isdisjoint(l[i])):
            i += 1
        j = 0
        while ({s[1]}.isdisjoint(l[j])):
            j += 1
        if i!=j:
            l[i].update(l[j])
            l.pop(j)
        #print(l)
    count = q+2*(len(l)-1)
    print('Case #',k,': ',count,sep='')



This passes the sample cases but fails the test sets. As far as I can tell, it should be correct. Am I doing something wrong?


Solution

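  • Two issues:

    1. The formula count = q+2*(len(l)-1) counts every black strand. A black strand whose two cherries are already in the same set closes a cycle and is not part of the spanning forest, so it must not be counted. Only the strands that merge two sets belong to the forest, and there are exactly n - len(l) of them, which gives (n - len(l)) + 2*(len(l) - 1).
    2. Finding the set that contains a cherry by scanning the list l takes time proportional to the number of sets, and merging Python sets copies their elements. With the many cherries and strands of the large test set this is far too slow. A proper Union-Find (disjoint set) structure with path compression and union by rank makes each operation nearly constant time.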
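    A quick way to check the counting formula, independently of any Union-Find code, is to compare it against the cost derived from the number of connected components c of the black graph: a spanning forest uses n - c black strands and c - 1 red strands, so the sugar needed is (n - c) + 2(c - 1). The sketch below is a hypothetical test harness (the helper names count_components, question_formula and correct_formula are made up for illustration). It counts components with a plain breadth-first search and shows the question's formula overcounting when the black strands form a cycle:

```python
from collections import defaultdict, deque


def count_components(n, edges):
    """Count connected components of an undirected graph on nodes 1..n (BFS)."""
    adjacency = defaultdict(list)
    for a, b in edges:
        adjacency[a].append(b)
        adjacency[b].append(a)
    seen = set()
    components = 0
    for start in range(1, n + 1):
        if start in seen:
            continue
        # Every unvisited start node begins a new component
        components += 1
        seen.add(start)
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbour in adjacency[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append(neighbour)
    return components


def question_formula(n, edges):
    # q + 2*(c - 1): counts every black strand, even redundant ones
    return len(edges) + 2 * (count_components(n, edges) - 1)


def correct_formula(n, edges):
    # (n - c) black strands in a spanning forest + 2*(c - 1) red strands
    c = count_components(n, edges)
    return (n - c) + 2 * (c - 1)


triangle = [(1, 2), (2, 3), (1, 3)]  # three cherries, black strands form a cycle
print(question_formula(3, triangle))  # 3
print(correct_formula(3, triangle))   # 2
```

    Without cycles both formulas agree; the difference appears only when some black strands are redundant, which the sample cases apparently do not exercise.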

    I would also advise using meaningful variable names; they make the code much easier to understand. One-letter variables like t, q or s are not very helpful.

    There are several ways to implement the Union-Find operations. Here I have defined a Node class with find and union methods:

    # Implementation of Union-Find (Disjoint Set)
    class Node:
        def __init__(self):
            self.parent = self
            self.rank = 0
    
        def find(self):
            if self.parent.parent != self.parent:
                self.parent = self.parent.find()
            return self.parent
    
        def union(self, other):
            node = self.find()
            other = other.find()
            if node == other:
                return True # was already in same set
            if node.rank > other.rank:
                node, other = other, node
            node.parent = other
            other.rank = max(other.rank, node.rank + 1)
            return False # was not in same set, but now is
    
    testcount = int(input())
    for testid in range(1, testcount + 1):
        nodecount, blackcount = map(int, input().split())
        # use Union-Find data structure
        nodes = [Node() for _ in range(nodecount)]
        blackedges = []
        for _ in range(blackcount):
            start, end = map(int, input().split())
            blackedges.append((nodes[start - 1], nodes[end - 1]))
    
        # Start with assumption that all edges on MST are red:
        sugarcount = nodecount * 2 - 2
        for start, end in blackedges:
            if not start.union(end): # When edge connects two disjoint sets:
                sugarcount -= 1 # Use this black edge instead of red one
    
        print('Case #{}: {}'.format(testid, sugarcount))
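    If you prefer not to create one object per cherry, the same idea can be written with two plain lists. The sketch below is an alternative, not the method used above: it swaps in union by size instead of union by rank and an iterative find with path halving, which avoids recursion entirely. The function name min_sugar is hypothetical; it is meant to return the same value as the sugarcount loop above.

```python
def min_sugar(n, black_strands):
    """Minimum sugar for n cherries (numbered 1..n) given black strands (a, b).

    Array-based Union-Find with union by size and iterative path halving.
    """
    parent = list(range(n + 1))  # index 0 is unused
    size = [1] * (n + 1)

    def find(x):
        # Path halving: point every other node on the path at its grandparent
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    merged = 0  # black strands that joined two different sets
    for a, b in black_strands:
        root_a, root_b = find(a), find(b)
        if root_a == root_b:
            continue  # redundant strand: it would close a cycle
        if size[root_a] < size[root_b]:
            root_a, root_b = root_b, root_a
        # Attach the smaller tree under the larger one
        parent[root_b] = root_a
        size[root_a] += size[root_b]
        merged += 1

    # All n - 1 tree edges start as red (cost 2); each useful black strand saves 1
    return 2 * (n - 1) - merged


print(min_sugar(2, [(1, 2)]))                  # 1
print(min_sugar(3, [(2, 3)]))                  # 3
print(min_sugar(3, [(1, 2), (2, 3), (1, 3)]))  # 2
```

    Union by size with path halving has the same near-constant amortized cost per operation as union by rank with path compression, so either combination is fast enough here.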