Python: cluster connected elements with an n-to-m relationship


This is not a homework task (please see my profile). I do not have a computer science background, and this question came up in an applied machine learning problem. I am fairly sure I am not the first person to run into this, so I am looking for an elegant solution. I would prefer a solution that uses a Python library over a raw implementation.

Assume we have, as input, a dictionary connecting letters to numbers:

    connected = {
        'A': [1, 2, 3],
        'B': [3, 4],
        'C': [5, 6],
    }

Each letter can be connected to multiple numbers, and one number can be connected to multiple letters. However, each letter can be connected to a given number only once.

Looking at the dictionary, we see that the number 3 is connected to both 'A' and 'B', so 'A' and 'B' belong in the same cluster. None of the numbers of 'C' appear under any other letter, so 'C' forms a cluster on its own. The expected output is therefore:
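The observation about the number 3 can be made mechanical by inverting the dictionary into a number-to-letters index. This is a small standard-library sketch; `letters_by_number` and `shared` are illustrative names, not part of the original question.

```python
from collections import defaultdict

connected = {'A': [1, 2, 3], 'B': [3, 4], 'C': [5, 6]}

# Invert the mapping: for every number, record which letters point to it.
letters_by_number = defaultdict(list)
for letter, numbers in connected.items():
    for number in numbers:
        letters_by_number[number].append(letter)

# Numbers listed under more than one letter are what merges letters together.
shared = {n: ls for n, ls in letters_by_number.items() if len(ls) > 1}
print(shared)  # {3: ['A', 'B']}
```

This only reveals letters that share a number directly. If 'B' also shared a number with some letter 'D', then 'A' and 'D' would end up in the same cluster without having any number in common, which is why the grouping has to be transitive.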

    cluster = {
        '1': {
            'letters': ['A', 'B'],
            'numbers': [1, 2, 3, 4],
        },
        '2': {
            'letters': ['C'],
            'numbers': [5, 6],
        }
    }

I think this is related to graph algorithms and connected subgraphs, but I do not know where to start.
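One way to make the graph view concrete: treat every letter and every number as a node of an undirected bipartite graph, with one edge per letter-number pair. The clusters are then exactly the connected components of that graph. Below is a minimal sketch using plain breadth-first search from the standard library rather than a graph library (libraries such as networkx provide `connected_components` for the same job). The function name `cluster_bfs` and the tagged `("L", ...)` / `("N", ...)` node keys are choices made for this sketch, and the sorting assumes the letters are mutually comparable, as are the numbers.

```python
from collections import deque


def cluster_bfs(connected):
    # Build an undirected bipartite adjacency list. Nodes are tagged tuples,
    # so a letter and a number can never be confused with each other.
    adj = {}
    for letter, numbers in connected.items():
        adj.setdefault(("L", letter), set())
        for number in numbers:
            adj[("L", letter)].add(("N", number))
            adj.setdefault(("N", number), set()).add(("L", letter))

    seen = set()
    clusters = {}
    for start in adj:
        if start in seen:
            continue
        # Breadth-first search from an unvisited node collects one component.
        seen.add(start)
        queue = deque([start])
        component = []
        while queue:
            node = queue.popleft()
            component.append(node)
            for neighbour in adj[node]:
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append(neighbour)
        # Split the component back into letters and numbers.
        clusters[str(len(clusters) + 1)] = {
            "letters": sorted(v for kind, v in component if kind == "L"),
            "numbers": sorted(v for kind, v in component if kind == "N"),
        }
    return clusters


connected = {'A': [1, 2, 3], 'B': [3, 4], 'C': [5, 6]}
print(cluster_bfs(connected))
# {'1': {'letters': ['A', 'B'], 'numbers': [1, 2, 3, 4]}, '2': {'letters': ['C'], 'numbers': [5, 6]}}
```

Tagging the nodes keeps a letter and a number apart even if they happen to have equal values, and a letter with an empty list still yields its own one-letter cluster.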


Solution

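  • Using a union-find (disjoint-set) structure, you can solve this efficiently: with union by size and path compression, the running time is nearly linear in the total number of letter-number pairs. The key idea is to union each letter with every number in its list. Once you have done this for all letters, each resulting set is exactly one cluster with the desired property: it contains all letters that are linked through shared numbers, directly or via a chain, together with those numbers.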

    from collections import defaultdict


    class UnionFind:
        """Disjoint-set forest with union by size and path compression."""

        def __init__(self):
            self.id = {}    # element -> parent pointer (a root points to itself)
            self.size = {}  # root -> number of elements in its tree

        def add(self, a):
            # Register a as a singleton set if it has not been seen yet.
            if a not in self.id:
                self.id[a] = a
                self.size[a] = 1

        def find(self, a):
            # Walk up to the root, remembering the path.
            cur = a
            path = []
            while self.id[cur] != cur:
                path.append(cur)
                cur = self.id[cur]
            # Path compression: point every visited node directly at the root.
            for x in path:
                self.id[x] = cur
            return cur

        def union(self, a, b):
            self.add(a)
            self.add(b)
            roota, rootb = self.find(a), self.find(b)
            if roota != rootb:
                # Union by size: attach the smaller tree below the larger one.
                if self.size[roota] > self.size[rootb]:
                    roota, rootb = rootb, roota
                self.id[roota] = rootb
                self.size[rootb] += self.size[roota]


    if __name__ == "__main__":
        uf = UnionFind()
        connected = {
            'A': [1, 2, 3],
            'B': [3, 4],
            'C': [5, 6],
        }
        for letter, numbers in connected.items():
            uf.add(letter)  # keeps a letter with an empty list as its own cluster
            for number in numbers:
                uf.union(letter, number)

        # Group by the root returned by find(); uf.id[key] is only a parent
        # pointer and is not guaranteed to be the root.
        clusters = defaultdict(list)
        for key in list(uf.id):
            clusters[uf.find(key)].append(key)

        # Letters are assumed to be strings and numbers non-strings.
        formatted_clusters = {}
        for i, cluster_elements in enumerate(clusters.values(), start=1):
            formatted_clusters[str(i)] = {
                "letters": [e for e in cluster_elements if isinstance(e, str)],
                "numbers": [e for e in cluster_elements if not isinstance(e, str)],
            }
        print(formatted_clusters)
    

    Output:

    {'1': {'letters': ['A', 'B'], 'numbers': [1, 2, 3, 4]}, '2': {'letters': ['C'], 'numbers': [5, 6]}}
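One subtlety when reading clusters out of a union-find structure: `id[x]` holds a parent pointer, which equals the root only if that node's path has been compressed since its root last changed. Grouping by the raw `id` values instead of calling `find()` is therefore a tempting but unsafe shortcut. The self-contained sketch below re-declares a condensed union-find with the same union-by-size rule (only so the example runs on its own) and shows a sequence of unions where grouping by raw parent pointers splits one cluster in two, while grouping by `find()` does not. The element names `a` to `d` are arbitrary.

```python
from collections import defaultdict


class UnionFind:
    """Condensed union-find using the same union-by-size rule as above."""

    def __init__(self):
        self.id = {}
        self.size = {}

    def find(self, a):
        # Locate the root, then compress the path from a to the root.
        root = a
        while self.id[root] != root:
            root = self.id[root]
        while a != root:
            parent = self.id[a]
            self.id[a] = root
            a = parent
        return root

    def union(self, a, b):
        for x in (a, b):
            if x not in self.id:
                self.id[x] = x
                self.size[x] = 1
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            if self.size[ra] > self.size[rb]:
                ra, rb = rb, ra
            self.id[ra] = rb  # attach the smaller (or equal) tree below rb
            self.size[rb] += self.size[ra]


uf = UnionFind()
uf.union("a", "b")  # a -> b
uf.union("c", "d")  # c -> d
uf.union("a", "c")  # roots b and d have equal size, so b -> d; a still points to b

# Grouping by raw parent pointers (done before any further find() calls).
raw = defaultdict(list)
for key, parent in uf.id.items():
    raw[parent].append(key)

# Grouping by the true root.
roots = defaultdict(list)
for key in list(uf.id):
    roots[uf.find(key)].append(key)

print(dict(raw))    # {'b': ['a'], 'd': ['b', 'c', 'd']}
print(dict(roots))  # {'d': ['a', 'b', 'c', 'd']}
```

On the three-letter example from the question every parent pointer happens to be a root, so both groupings agree there; the difference only shows up once two multi-element trees are merged.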