Tags: python, algorithm, performance, dictionary, set

Fastest way to find the fewest subsets whose union is the total set in Python


Say I have a dictionary of sets like this:

d = {'a': {1,2,8}, 'b': {3,1,2,6}, 'c': {0,4,1,2}, 'd': {9}, 'e': {2,5},
     'f': {4,8}, 'g': {0,9}, 'h': {7,2,3}, 'i': {5,6,3}, 'j': {4,6,8}}

Each set is a subset of a total set s = set(range(10)). I would like an efficient algorithm that finds the fewest keys whose sets together make up the whole set, or returns an empty list if no combination of keys covers it. If several combinations achieve the minimum number of keys, any one of them will do.
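
For concreteness, the behavior I want (min_cover is just a placeholder name for the function):

min_cover(d, s)              # -> ['a', 'c', 'd', 'h', 'i'] (any minimal cover is fine)
min_cover({'a': {1, 2}}, s)  # -> [] (no combination of keys covers s)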

So far I am using an exhaustive approach that checks all possible combinations and then takes a combination with the fewest keys.

import copy

def append_combinations(combo, keys):
    # Recursively extend `combo` with each remaining key; record every
    # extension whose sets union to the whole set s.
    for i in range(len(keys)):
        new_combo = copy.copy(combo)
        new_combo.append(keys[i])
        new_keys = keys[i+1:]
        if {n for k in new_combo for n in d[k]} == s:
            valid_combos.append(new_combo)
        append_combinations(new_combo, new_keys)

s = set(range(10))  # the total set to cover
valid_combos = []
combo = []
keys = sorted(d.keys())
append_combinations(combo, keys)
sorted_combos = sorted(valid_combos, key=len)
print(sorted_combos[0])
# ['a', 'c', 'd', 'h', 'i']

However, this becomes very expensive when the dictionary has many keys, since it enumerates all 2**n - 1 non-empty key combinations (in practice I will have around 100 keys). Any suggestions for a faster algorithm?


Solution

  • "Fast" is relative. This is the classic minimum set cover problem, which maps directly onto a mixed-integer linear program: one binary variable per key, minimize how many keys are selected, subject to every value being covered by at least one selected set. For this size of input, it's quite fast to build up a sparse cover constraint matrix like this:

    import time
    
    import numpy as np
    from scipy.optimize import milp, Bounds, LinearConstraint
    from scipy.sparse import coo_array
    
    sets = {
        'a': {1,2,8}, 'b': {3,1,2,6}, 'c': {0,4,1,2}, 'd': {9}, 'e': {2,5},
        'f': {4,8}, 'g': {0,9}, 'h': {7,2,3}, 'i': {5,6,3}, 'j': {4,6,8},
    }
    m = 1 + max(v for vset in sets.values() for v in vset)
    n = len(sets)
    
    # Variables: for each set, binary assignment.
    cost = np.ones(n)  # minimize keys assigned
    bounds = Bounds(lb=0, ub=1)  # binary variables
    
    # For every value: it must be covered by at least one set
    coords = np.array(
        [
            (y, x)
            for y in range(m)
            for x, vset in enumerate(sets.values())
            if y in vset
        ],
        dtype=np.int32,
    ).T
    cover_constraint = LinearConstraint(
        A=coo_array((
            np.ones(coords.shape[1]),  # data
            coords,
        )).tocsc(),
        lb=1,
    )
    
    start = time.perf_counter()
    result = milp(
        c=cost, integrality=1, bounds=bounds, constraints=cover_constraint,
    )
    end = time.perf_counter()
    assert result.success
    
    print(result.message)
    print(f'Cost: {result.fun} in {(end - start)*1e3:.2f} ms')
    print(
        'Assigned:',
        ' '.join(
            k
            for k, assigned in zip(sets.keys(), result.x)
            if assigned > 0.5  # result.x entries are floats; threshold to binary
        ),
    )
    
    Optimization terminated successfully. (HiGHS Status 7: Optimal)
    Cost: 5.0 in 2.58 ms
    Assigned: c g h i j
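
    The question also asks for an empty list when no combination of keys covers the set. A minimal sketch of wrapping the above into a reusable function (min_cover is my name for it, not part of scipy); milp reports infeasibility via result.success:

    def min_cover(sets, total):
        universe = sorted(total)
        index = {v: i for i, v in enumerate(universe)}
        keys = list(sets)
        # One (row, col) entry per value covered by a set.
        coords = np.array(
            [(index[v], x) for x, k in enumerate(keys)
             for v in sets[k] if v in index],
            dtype=np.int32,
        ).T
        constraint = LinearConstraint(
            A=coo_array(
                (np.ones(coords.shape[1]), coords),
                shape=(len(universe), len(keys)),
            ).tocsc(),
            lb=1,
        )
        result = milp(
            c=np.ones(len(keys)), integrality=1,
            bounds=Bounds(lb=0, ub=1), constraints=constraint,
        )
        if not result.success:  # infeasible: some value is in no set
            return []
        return [k for k, assigned in zip(keys, result.x) if assigned > 0.5]

    print(min_cover(sets, set(range(10))))  # e.g. ['c', 'g', 'h', 'i', 'j']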
    

    For 100 keys:

    rand = np.random.default_rng(seed=0)
    m = 10
    n = 100
    sets = {
        i: rand.choice(
            a=m, size=rand.integers(low=1, high=4, endpoint=True),
            replace=False, shuffle=False,
        )
        for i in range(n)
    }
    # ...
    print(result.message)
    print(f'Cost: {result.fun} in {(end - start)*1e3:.2f} ms')
    print(
        'Assigned:',
        ' '.join(
            str(k)
            for k, assigned in zip(sets.keys(), result.x)
            if assigned > 0.5
        ),
    )
    for (k, vset), assigned in zip(sets.items(), result.x):
        if assigned > 0.5:
            print(f'{k}: {vset}')
    
    Optimization terminated successfully. (HiGHS Status 7: Optimal)
    Cost: 3.0 in 18.49 ms
    Assigned: 0 82 99
    0: [4 7 2 3]
    82: [6 0 5 8]
    99: [1 5 4 9]
    

    It processes an input of size ~1200 in about a second on my nothing-special laptop.
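
    If you want to reproduce that timing, here is a sketch of the scaled benchmark, assuming "size" means the number of keys (the universe stays small, as in the 100-key run) and reusing the min_cover sketch from above:

    import time

    rand = np.random.default_rng(seed=0)
    big = {
        i: rand.choice(
            a=10, size=rand.integers(low=1, high=4, endpoint=True),
            replace=False, shuffle=False,
        )
        for i in range(1200)
    }
    start = time.perf_counter()
    cover = min_cover(big, set(range(10)))
    end = time.perf_counter()
    print(len(cover), f'{end - start:.2f} s')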