Tags: python, python-3.x, algorithm

What's a fast way to identify all overlapping sets?


I need to identify and consolidate all intersecting sets so that I ultimately have completely discrete sets that share no values between them. The sets currently exist as values in a dictionary, and the dictionary keys are sorted in a priority order that needs to be preserved.

For example, starting with the following sets in dictionary d:

d = {'b': {'b', 'f', 'a'},
     'x': {'x'},
     's': {'s'},
     'a': {'a', 'f', 'e'},
     'e': {'e'},
     'f': {'f'},
     'z': {'x', 'z'},
     'g': {'g'}}

...I'm trying to consolidate the sets to:

{'b': {'a', 'b', 'e', 'f'},
 'x': {'x', 'z'},
 's': {'s'},
 'a': {'a', 'b', 'e', 'f'},
 'e': {'a', 'b', 'e', 'f'},
 'f': {'a', 'b', 'e', 'f'},
 'z': {'x', 'z'},
 'g': {'g'}}

...by consolidating any sets that overlap with other sets.

I have code that gets me to these results:

d_size = {k: len(v) for (k, v) in d.items()}  # current size of every set
static = False
while not static:
    # push each set's members into the sets of all of its members
    for (k, vs) in d.items():
        for v in vs:
            if v == k:
                continue
            d[v].update(vs)

    # stop once a full pass leaves every set's size unchanged
    static = True
    for (k, v) in d_size.items():
        if len(d[k]) != v:
            d_size[k] = len(d[k])
            static = False

. . . but it is prohibitively slow for the datasets it needs to handle, which regularly extend to a few hundred thousand rows. The sets can be of arbitrary length, but usually contain fewer than 50 values once the consolidation is complete.

Ultimately I then need to remove the duplicate sets, keeping the first appearance in the dictionary (the keys are sorted by priority), so the final result for this toy example would be:

{'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}

I only need this final dictionary for my purposes, so I am open to any solution that doesn't produce the intermediate dictionary with duplicate sets.

Lastly, these results are mapped back into a pandas DataFrame, so I am open to solutions incorporating pandas (or NumPy). Any other third-party package would need to justify its footprint against its benefit.
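
For context, the mapping back into the dataframe is roughly along these lines (the frame and the column names here are placeholders, not my real data):

import pandas as pd

# toy frame standing in for the real data; 'key' holds the dictionary keys
df = pd.DataFrame({'key': list('bxsaefzg')})

# invert the final consolidated dict (value -> surviving group key)...
final = {'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
value_to_group = {v: k for k, vs in final.items() for v in vs}

# ...and map it back onto the frame
df['group'] = df['key'].map(value_to_group)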


Solution

  • This is a graph problem; you can use networkx.connected_components:

    # pip install networkx
    import networkx as nx
    
    # make graph
    G = nx.from_dict_of_lists(d)
    
    # identify connected components
    sets = {n: c for c in nx.connected_components(G) for n in c}
    
    # keep original order
    out = {n: sets[n] for n in d}
    

    Output:

    {'b': {'a', 'b', 'e', 'f'},
     'x': {'x', 'z'},
     's': {'s'},
     'a': {'a', 'b', 'e', 'f'},
     'e': {'a', 'b', 'e', 'f'},
     'f': {'a', 'b', 'e', 'f'},
     'z': {'x', 'z'},
     'g': {'g'}}
    

    If you just need the first (highest-priority) key per group, you could replace the last step with:

    out = {}
    seen = set()
    for k in d:
        if k not in seen:
            out[k] = sets[k]
            seen.update(sets[k])
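
    Here seen accumulates every value already assigned to a kept group, so any later key from the same component is skipped and only its first (highest-priority) key survives.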
    

    Alternatively, use min with a dictionary of weights to identify the first key per group:

    import networkx as nx
    
    # make graph
    G = nx.from_dict_of_lists(d)
    
    weights = {k: i for i, k in enumerate(d)}
    # {'b': 0, 'x': 1, 's': 2, 'a': 3, 'e': 4, 'f': 5, 'z': 6, 'g': 7}
    
    out = {min(c, key=weights.get): c for c in nx.connected_components(G)}
    

    Output:

    {'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
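
    Since weights records each key's position in d, min(c, key=weights.get) picks the key of each connected component that appears earliest, i.e. the highest-priority one.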
    

    Pure Python solution:

    out = {}      # consolidated groups, keyed by their highest-priority key
    mapper = {}   # value -> key of the group it currently belongs to
    seen = set()  # every value already assigned to a group

    for k, s in d.items():
        if (common := s & seen):
            # merge into every group this set touches, keeping the
            # highest-priority root (dict order == priority order)
            roots = {mapper[v] for v in common}
            root = next(r for r in out if r in roots)
            for r in roots - {root}:
                out[root] |= out.pop(r)
            out[root] |= s
            mapper.update(dict.fromkeys(out[root], root))
        else:
            out[k] = set(s)  # copy so the input sets are left untouched
            mapper.update(dict.fromkeys(s, k))
        seen |= s

    print(out)
    # {'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
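
    This makes a single pass over the dictionary: mapper remembers which kept group each value already belongs to, so each new set is folded into the group(s) it overlaps, or starts a new one, without the repeated full re-scans that make the original loop slow.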