python, python-3.x, list, dictionary, set-intersection

Preserve list order that comes from the intersection of more than two lists


I have the following dictionary:

di = {
    'A': [['A1', 'A1a'], ['A1', 'A1a'], ['A1', 'A1a'], ['A1', 'A1a'], ['A1', 'A1a']], 
    'B': [['A1', 'BT', 'B2', 'B2a', 'B2a1a']], 
    'G': [['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a', 'G2a2b', 'G2a2b2b', 'G2a2b2b1a'],
          ['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a'],
          ['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a', 'G2a2b', 'G2a2b2a', 'G2a2b2a1a1b', 'G2a2b2a1a1b1', 'G2a2b2a1a1b1a2']]
}

Note that I have more than two lists for two of these keys, and only one list for B.

For each key I want the intersection of all of its lists, so that the result looks like the following:

intersection = {
    'A': ['A1', 'A1a'], 
    'B': ['A1', 'BT', 'B2', 'B2a', 'B2a1a'], 
    'G': ['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a']
}

I was originally using the following to get the intersection:

out = {}
for key in di:
  out[key] = set.intersection(*map(set, di[key]))

It works, but it loses the ordering. The last element of each intersection is the 'last common ancestor', or lca, and my goal is to find the lca for each key across many individuals. For example, the value for key G looks like this:

out['G']
{'A1', 'BT', 'CF', 'CT0', 'F5', 'G21', 'G2a', 'G4', 'GHIJK'}

This is wrong, because the lca should clearly be G2a, not GHIJK. After some digging, it seems the ordering is lost because sets are unordered. I have found plenty of ways to intersect two lists without converting them to sets, but is there a way to do this for more than two lists?


Solution

  • First compute the intersection of all the sublists as a set. Then use a list comprehension over the first sublist to keep only the elements that are in that set, so the result stays in the original order.

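    import pprint

    out = {}
    for key, (list1, *rest) in di.items():
        # Elements common to every sublist (the set itself has no order)
        intersection = set(list1).intersection(*rest)
        # Filter the first sublist to restore the original order
        out[key] = [i for i in list1 if i in intersection]
    pprint.pprint(out)
    # Output:
    {'A': ['A1', 'A1a'],
     'B': ['A1', 'BT', 'B2', 'B2a', 'B2a1a'],
     'G': ['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a']}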
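
Since the stated goal is the lca for each key, the same idea can be wrapped in a small helper and the last element of each ordered result picked out. This is only a sketch: the names `ordered_intersection` and `lca` are made up for illustration, `di` is a shortened version of the dictionary from the question, and it assumes that an empty intersection should map to `None`.

```python
def ordered_intersection(lists):
    """Return the elements common to every list, in the order of the first list."""
    if not lists:
        return []
    first, *rest = lists
    # set.intersection accepts any iterables, so the other lists need no conversion
    common = set(first).intersection(*rest)
    return [item for item in first if item in common]


# Shortened version of the question's data (assumption: same shape as the real input)
di = {
    'A': [['A1', 'A1a'], ['A1', 'A1a']],
    'B': [['A1', 'BT', 'B2', 'B2a', 'B2a1a']],
    'G': [['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a', 'G2a2b'],
          ['A1', 'BT', 'CT0', 'CF', 'F5', 'GHIJK', 'G4', 'G21', 'G2a']],
}

out = {key: ordered_intersection(lists) for key, lists in di.items()}

# The last element of each ordered intersection is the lca;
# an empty intersection maps to None (an assumption, not from the question)
lca = {key: (common[-1] if common else None) for key, common in out.items()}
print(lca)
# {'A': 'A1a', 'B': 'B2a1a', 'G': 'G2a'}
```

Note that the order comes from the first sublist, so any element repeated in that sublist would also be repeated in the result.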