I need to identify and consolidate all intersecting sets so that I ultimately have completely discrete sets that share no values between them. The sets currently exist as values in a dictionary, and the dictionary keys are sorted by priority that needs to be preserved.
For example, starting with the following sets in dictionary d:
d = {'b': {'b', 'f', 'a'},
     'x': {'x'},
     's': {'s'},
     'a': {'a', 'f', 'e'},
     'e': {'e'},
     'f': {'f'},
     'z': {'x', 'z'},
     'g': {'g'}}
...I'm trying to consolidate the sets to:
{'b': {'a', 'b', 'e', 'f'},
 'x': {'x', 'z'},
 's': {'s'},
 'a': {'a', 'b', 'e', 'f'},
 'e': {'a', 'b', 'e', 'f'},
 'f': {'a', 'b', 'e', 'f'},
 'z': {'x', 'z'},
 'g': {'g'}}
...by consolidating any sets that overlap with other sets.
I have code that gets me to these results:
d_size = {k: len(v) for (k, v) in d.items()}
static = False
while not static:
    for (k, vs) in d.items():
        for v in vs:
            if v == k:
                continue
            d[v].update(vs)
    static = True
    for (k, v) in d_size.items():
        if len(d[k]) != v:
            d_size[k] = len(d[k])
            static = False
...but it is prohibitively slow for the sizes of datasets it needs to handle, which regularly extend into a few hundred thousand rows. The sets are of arbitrary length, but usually fewer than 50 elements once consolidation is complete.
Ultimately I then need to remove the duplicate sets, ensuring that I keep the first appearance in the dictionary as the keys are sorted by priority, so the final result for this toy example set would be:
{'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
I only need this final dictionary for my purposes, so I'm open to any solution that doesn't produce the intermediate dictionary with duplicate sets.
Lastly, these results are mapped back into a pandas DataFrame, so I'm open to solutions incorporating pandas (or NumPy). Any other third-party package would need to justify its footprint against its benefit.
This is a graph problem; you can use networkx.connected_components:
# pip install networkx
import networkx as nx
# make graph
G = nx.from_dict_of_lists(d)
# identify connected components
sets = {n: c for c in nx.connected_components(G) for n in c}
# keep original order
out = {n: sets[n] for n in d}
Output:
{'b': {'a', 'b', 'e', 'f'},
 'x': {'x', 'z'},
 's': {'s'},
 'a': {'a', 'b', 'e', 'f'},
 'e': {'a', 'b', 'e', 'f'},
 'f': {'a', 'b', 'e', 'f'},
 'z': {'x', 'z'},
 'g': {'g'}}
If you just need the first key, you could replace the last step with:
out = {}
seen = set()
for k in d:
    if k not in seen:
        out[k] = sets[k]
        seen.update(sets[k])
Alternatively, use min with a dictionary of weights to identify the first key per group:
import networkx as nx
# make graph
G = nx.from_dict_of_lists(d)
weights = {k: i for i, k in enumerate(d)}
# {'b': 0, 'x': 1, 's': 2, 'a': 3, 'e': 4, 'f': 5, 'z': 6, 'g': 7}
out = {min(c, key=weights.get): c for c in nx.connected_components(G)}
Output:
{'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
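Since the question mentions mapping the results back into a pandas DataFrame, the node-to-component mapping can be applied directly with Series.map. A sketch (the DataFrame and its 'key' column are hypothetical stand-ins for your real data):

```python
import networkx as nx
import pandas as pd

d = {'b': {'b', 'f', 'a'}, 'x': {'x'}, 's': {'s'}, 'a': {'a', 'f', 'e'},
     'e': {'e'}, 'f': {'f'}, 'z': {'x', 'z'}, 'g': {'g'}}

G = nx.from_dict_of_lists(d)
# map every node to its full connected component
comp = {n: c for c in nx.connected_components(G) for n in c}
# and to a single representative: the first dictionary key in its component
weights = {k: i for i, k in enumerate(d)}
rep = {n: min(c, key=weights.get) for n, c in comp.items()}

# hypothetical DataFrame with one row per key
df = pd.DataFrame({'key': list(d)})
df['group'] = df['key'].map(rep)
```

Grouping or deduplicating rows then becomes an ordinary groupby on the 'group' column.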
For a solution without networkx, you can build the final dictionary in a single pass over d, merging any existing groups that the current set bridges:

out = {}
mapper = {}  # element -> key of the group it currently belongs to
seen = set()
for k, s in d.items():
    if (common := s & seen):
        # s overlaps existing groups: merge them all into the earliest one
        roots = {mapper[v] for v in common}
        root = next(r for r in out if r in roots)
        for r in roots - {root}:
            out[root] |= out.pop(r)
        out[root] |= s
        mapper.update(dict.fromkeys(out[root], root))
    else:
        out[k] = s
        mapper.update(dict.fromkeys(s, k))
    seen |= s
print(out)
# {'b': {'a', 'b', 'e', 'f'}, 'x': {'x', 'z'}, 's': {'s'}, 'g': {'g'}}
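For very large inputs, another networkx-free option is a classic union-find (disjoint-set) structure, which resolves all overlaps in near-linear time. A sketch, with `consolidate` as a hypothetical helper name; it assumes, as in the example, that every element also appears as a dictionary key (or at least that each component contains one key):

```python
def consolidate(d):
    # union-find (disjoint-set) with path halving
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # union each key with every member of its set
    for k, s in d.items():
        for v in s:
            ra, rb = find(k), find(v)
            if ra != rb:
                parent[rb] = ra

    # label each component with the first (highest-priority) dict key in it
    first = {}  # root -> first key in d belonging to that component
    out = {}
    for k in d:
        r = find(k)
        if r not in first:
            first[r] = k
            out[k] = set()
    for n in parent:
        out[first[find(n)]].add(n)
    return out

d = {'b': {'b', 'f', 'a'}, 'x': {'x'}, 's': {'s'}, 'a': {'a', 'f', 'e'},
     'e': {'e'}, 'f': {'f'}, 'z': {'x', 'z'}, 'g': {'g'}}
out = consolidate(d)
```

Because dicts preserve insertion order, iterating `for k in d` guarantees the surviving key for each group is the highest-priority one.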