
All-Different-Except Constraint in Z3


Is there a way to generate an all-different-except constraint in Z3 with only O(n) statements? I know that XCSP3 offers this.

Currently this can be done with O(n^2) statements, like so:

for i in range(len(G) - 1):
    s.add( [ Or(G[i] == 0, G[i] != G[j]) for j in range(i + 1, len(G)) ] )

If it matters, I'm interested in comparing bit vectors.
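To pin down the intended semantics, here is a plain-Python sketch (no Z3 needed) of what the constraint should accept for concrete values, assuming 0 is the ignored value. Both helper names, `distinct_except_holds` and `pairwise_disequalities`, are hypothetical and only for illustration. The second function counts the `Or(...)` constraints the loop above adds, which is n(n-1)/2.

```python
from collections import Counter

def distinct_except_holds(values, ignored=0):
    """True if every value other than `ignored` occurs at most once."""
    counts = Counter(v for v in values if v != ignored)
    return all(c <= 1 for c in counts.values())

def pairwise_disequalities(n):
    """Number of Or(...) constraints the O(n^2) loop in the question adds."""
    # For each i in range(n - 1), the loop adds one constraint per j > i.
    return sum(n - 1 - i for i in range(n - 1))

print(distinct_except_holds([0, 0, 3, 5]))  # True: only the ignored 0 repeats
print(distinct_except_holds([3, 0, 3, 5]))  # False: 3 repeats
print([pairwise_disequalities(n) for n in (2, 4, 10, 100)])  # [1, 6, 45, 4950]
```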


Solution

  • Z3 does come with a Distinct predicate that ensures all elements are different, but to the best of my knowledge there is no built-in distinct-except.

    To encode this sort of a constraint, I'd use an array to keep track of the cardinality of inserted elements. Something like this:

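    from z3 import *
    
    def distinct_except(G, ignored):
        # With fewer than two elements the constraint holds trivially
        if len(G) < 2:
            return BoolVal(True)
    
        # A maps each value to how often it has been seen; every entry starts at 0
        A = K(G[0].sort(), 0)
        for g in G:
            # Increment the count stored at g, unless g is the ignored value (store 0)
            A = Store(A, g, If(g == ignored, 0, 1 + Select(A, g)))
    
        # Every element must read back a count of at most 1 (ignored ones read 0)
        return And([Select(A, g) <= 1 for g in G])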
    

    We simply insert the elements into an array one at a time, incrementing the stored count by 1 if the element is not the ignored value, and storing 0 otherwise. This avoids building large nests of if-then-else nodes, since the array is built linearly: each element adds exactly one Store on top of the previous array. We then walk over the list again and make sure that no element maps to a count larger than 1.
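The counting argument can be traced on concrete values with a plain-Python model of the same Store/Select sequence. This is only a sketch: a `dict` read with a default of 0 stands in for the constant array `K(..., 0)`, copying the dict on every step mirrors the fact that `Store` builds a new array term instead of mutating one, and `simulate_array_encoding` is a hypothetical name, not part of Z3.

```python
def simulate_array_encoding(values, ignored=0):
    """Mimic distinct_except on concrete values; return (final array, verdict)."""
    A = {}  # stands in for K(sort, 0): any key not present reads as 0

    def select(arr, k):
        return arr.get(k, 0)

    for g in values:
        # Mirrors Store(A, g, If(g == ignored, 0, 1 + Select(A, g)))
        A = {**A, g: 0 if g == ignored else 1 + select(A, g)}

    # Mirrors the final check: every element reads back a count of at most 1
    return A, all(select(A, g) <= 1 for g in values)

print(simulate_array_encoding([0, 7, 0, 9]))  # ({0: 0, 7: 1, 9: 1}, True)
print(simulate_array_encoding([7, 0, 7, 9]))  # ({7: 2, 0: 0, 9: 1}, False)
```

The repeated 0 in the first call never raises a count, while the repeated 7 in the second call drives its entry to 2 and fails the final check.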

    Building the constraint this way keeps the sizes of both the returned constraint and the array term A linear in len(G), and z3 should be able to deal with it fairly efficiently. I'd like to hear if you find otherwise.
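As a rough back-of-the-envelope check of the linear-size claim, the sketch below tallies the operations each encoding builds for n elements. The per-element tallies come from reading the two snippets (one Store, one If, one equality, one addition and one Select per element while building A; one Select and one comparison per element in the check), they ignore any sharing of repeated subterms, and they are not numbers reported by Z3. `s.sexpr()` shows the real terms. Both function names are hypothetical.

```python
def array_encoding_ops(n):
    """Approximate operation count for distinct_except on n >= 2 elements."""
    build = 5 * n      # Store, If, ==, +, Select per element
    check = 2 * n + 1  # Select and <= per element, plus one n-ary And
    return build + check

def pairwise_encoding_ops(n):
    """Approximate operation count for the O(n^2) loop: an Or, == and != per pair."""
    pairs = n * (n - 1) // 2
    return 3 * pairs

for n in (4, 16, 64, 256):
    print(n, array_encoding_ops(n), pairwise_encoding_ops(n))
# 4 29 18
# 16 113 360
# 64 449 6048
# 256 1793 97920
```

By this rough count, the array version only pays off once n grows past a handful of variables. For very small lists the pairwise encoding is just as compact.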

    Here are a few tests to see it in action:

    # Test: create four variables and assert they are distinct in the above sense
    a, b, c, d = BitVecs('a b c d', 32)
    s = Solver()
    s.add(distinct_except([a, b, c, d], 0))
    
    # Clearly sat:
    print(s.check())
    print(s.model())
    
    # Add a constraint that a and b are equal.
    # Still sat, because we can make both a and b 0.
    s.add(a == b)
    print(s.check())
    print(s.model())
    
    # Force a to be non-zero
    s.add(a != 0)
    
    # Now we have unsat:
    print(s.check())
    

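    This prints something like the following (the exact model values depend on the Z3 version and may differ on your machine):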

    sat
    [b = 1024, a = 16, c = 1, d = 536870912]
    sat
    [c = 33554432, a = 0, d = 32768, b = 0]
    unsat
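
The reported models can be sanity-checked against the intended semantics in plain Python. In the sketch below, `distinct_except_holds` is a hypothetical helper that re-implements the semantics with a set. The last line spells out why the final check is unsat: with a == b and a != 0, the same non-ignored value appears twice. The exhaustive search runs over a small 3-bit domain standing in for the 32-bit one, an assumption made purely to keep it fast.

```python
def distinct_except_holds(values, ignored=0):
    """True if the non-ignored values are pairwise different."""
    kept = [v for v in values if v != ignored]
    return len(kept) == len(set(kept))

# The two models printed above
m1 = {"a": 16, "b": 1024, "c": 1, "d": 536870912}
m2 = {"a": 0, "b": 0, "c": 33554432, "d": 32768}

print(distinct_except_holds(list(m1.values())))  # True
print(distinct_except_holds(list(m2.values())) and m2["a"] == m2["b"])  # True

# Third check: search every a == b != 0 with arbitrary c, d over 3-bit values
print(any(distinct_except_holds([a, a, c, d])
          for a in range(1, 8) for c in range(8) for d in range(8)))  # False
```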
    

    Note that you can always call print(s.sexpr()) before s.check() to see the expressions you have built, and to watch how they grow as your input lists get larger.