
Enumerate partial sums using Z3


My business goal is to find selections from a really big array of possible values whose partial sum doesn't exceed a specific number. Partial sums are unordered (I mean A[1] + A[3] is the same as A[3] + A[1]). The following minimal example models the real problem, but with some concerns:

import z3
nums = [6, 1, 2, 3, 4, 7, 8, 10, 11, 0, 0, 0, 0, 0, 0] # Problem #1 - padding 0
GOAL = 14 

xis = z3.Ints('x1 x2 x3 x4 x5 x6') # indexes to point inside array 'a'
sol = z3.Optimize()
sol.add(z3.Distinct(*xis)) 
sol.add( [z3.And( x >=0, x < len(nums)) for x in xis] )
# Problem #2 - how correctly apply AtLeast/AtMost statements
#sol.add( z3.AtLeast( *[x >=0  for x in xis], 1 ))
#sol.add( z3.AtMost( *[x >=0  for x in xis], 4 ))

a = z3.Array('a', z3.RealSort(), z3.IntSort())

for i, r in enumerate(nums):
    a = z3.Store(a, i, r)

s = z3.Real('S')

sol.add(z3.Sum([z3.Select(a, x) for x in xis]) == s)
sol.add(s <= GOAL) 
sol.maximize(s) # I need this to ensure to be close as possible to GOAL

while sol.check() == z3.sat:
    model = sol.model()
    print("==="*10)
    print(model)
    for x in xis: # Problem #3 - O(n^2) loop for exclusion
        excl = model[x]
        sol.add(z3.And(*(y != excl for y in xis)))

I've marked the problems I see with comments:

Problem #1 and #2: the sum can be made of 4 to 6 items, which is why I had to pad the list with 0's so that the indexes can reach that. I know about AtLeast and AtMost but have no idea how to leverage them here.

Problem #3: after each successful check I need an O(n^2) loop to force x1..x6 not to reuse already used indexes. Is there something like NOT IN to simplify the uniqueness checking?


Solution

  • The easiest way to do this is to use the optimizer to find the maximum possible value, then assert that value as a constraint and iterate over all solutions. Instead of "symbolically indexing" into a constant list, use one boolean per value to indicate whether it is picked or not. This avoids both the extra padding 0's and the arrays.

    The other "trick" for speed is to use the optimizer only to find the max value, but a regular solver to find all the solutions. Regular solvers are incremental, i.e., they can handle new constraints added after a call to check. The optimizer, unfortunately, restarts, which can cause a slow-down.
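
    For small inputs, the same two-phase idea (first find the best sum, then list every picking that reaches it) can be sanity-checked without any solver. The sketch below is plain Python; `best_pickings` is a hypothetical helper name, not part of z3. It walks over all 2^n subsets, so it is only a reference for testing, not an option for the really big arrays in the question.

```python
from itertools import combinations


def best_pickings(nums, goal, min_count=0, max_count=None):
    """Brute-force reference: return the largest sum <= goal and every picking
    (by position) reaching it, among pickings of min_count..max_count items."""
    if max_count is None:
        max_count = len(nums)
    best, pickings = None, []
    for k in range(min_count, max_count + 1):
        # combinations() is unordered, so A[1] + A[3] and A[3] + A[1] appear once.
        for combo in combinations(range(len(nums)), k):
            total = sum(nums[i] for i in combo)
            if total > goal:
                continue
            if best is None or total > best:
                best, pickings = total, []  # better sum found: restart the list
            if total == best:
                pickings.append([nums[i] for i in combo])
    return best, pickings


nums = [6, 1, 2, 3, 4, 7, 8, 10, 11]
best, pickings = best_pickings(nums, 14)
print(best, len(pickings))  # 14 11
best, pickings = best_pickings(nums, 14, min_count=4, max_count=6)
print(best, pickings)  # 14 [[6, 1, 3, 4], [1, 2, 3, 8], [1, 2, 4, 7]]
```

    The restriction on the number of picked items is just the range of subset sizes here; in the solver encoding it becomes a constraint on the count of true booleans.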

    Putting all these ideas together, I'd code your problem as follows:

    from z3 import *
    
    nums = [6, 1, 2, 3, 4, 7, 8, 10, 11]
    GOAL = 14
    
    picked = list(zip([Bool('p' + str(i)) for i in range(len(nums))], nums))
    
    o = Optimize()
    
    total = 0
    for p, v in picked:
        total = If(p, total + v, total)
    
    o.add(total <= GOAL)
    o.maximize(total)
    
    res = o.check()
    if res == sat:
        m = o.model()
        s = Solver()
        s.add(total == m.evaluate(total))
        for m in all_smt(s, [p for p, _ in picked]):
            vals = []
            for p, v in picked:
                if m[p]:
                    vals += [v]
            print(vals)
    else:
        print("Optimizer said: " + str(res))
    

    This prints:

    [2, 4, 8]
    [6, 8]
    [6, 1, 3, 4]
    [6, 1, 7]
    [1, 3, 10]
    [1, 2, 11]
    [1, 2, 3, 8]
    [1, 2, 4, 7]
    [3, 11]
    [4, 10]
    [3, 4, 7]
    

    which I believe enumerates all possible "pickings" from your list whose sum reaches the maximum, 14.

    Note that the above code uses the function all_smt, which unfortunately does not come with z3 itself. You can find the code for that in Section 5.1 of https://theory.stanford.edu/~nikolaj/programmingz3.html.

    Limiting count

    If you want the solutions to have a certain number of elements, you can count the picked elements and assert that the count satisfies your requirements. Here's the modified program, limiting the count to between 4 and 6:

    from z3 import *
    
    nums = [6, 1, 2, 3, 4, 7, 8, 10, 11]
    GOAL = 14
    
    picked = list(zip([Bool('p' + str(i)) for i in range(len(nums))], nums))
    
    o = Optimize()
    
    total = 0
    count = 0
    for p, v in picked:
        total = If(p, total + v, total)
        count = If(p, count + 1, count)
    
    o.add(total <= GOAL)
    o.add(And(count >= 4, count <= 6))
    o.maximize(total)
    
    res = o.check()
    if res == sat:
        m = o.model()
        s = Solver()
        s.add(total == m.evaluate(total))
        s.add(And(count >= 4, count <= 6))
        for m in all_smt(s, [p for p, _ in picked]):
            vals = []
            for p, v in picked:
                if m[p]:
                    vals += [v]
            print(vals)
    else:
        print("Optimizer said: " + str(res))
    

    This prints:

    [6, 1, 3, 4]
    [1, 2, 3, 8]
    [1, 2, 4, 7]