Tags: z3, smt, z3py

Bug in documentation example for maximal satisfying subsets finder in z3


I am trying to use the code from a z3 documentation example to find maximally satisfying subsets in z3. Here's the code I copied:

from z3 import *

def main():
    """Print each maximal satisfying subset found for a small real-valued example."""
    x, y = Reals('x y')
    soft_constraints = [
        x > 2,
        x < 1,
        x < 0,
        Or(x + y > 0, y < 0),
        Or(y >= 0, x >= 0),
        Or(y < 0, x < 0),
        Or(y > 0, x < 0),
    ]
    hard_constraints = BoolVal(True)  # i.e. no hard constraints at all
    mss_iter = enumerate_sets(MSSSolver(hard_constraints, soft_constraints))
    for subset in mss_iter:
        print("%s" % subset)


def enumerate_sets(solver):
    """Yield successive MSSes produced by ``solver.grow()``.

    Iteration ends as soon as the underlying z3 solver ``solver.s`` no longer
    reports ``sat`` (i.e. it returns ``unsat`` or ``unknown``).
    """
    while solver.s.check() == sat:
        yield solver.grow()

class MSSSolver:
    """Enumerate maximal satisfying subsets (MSSes) of a list of soft constraints.

    Each soft constraint ``soft[i]`` gets a Boolean selector from :meth:`c_var`,
    and the solver asserts ``selector == soft[i]``.  The selector is *named*
    after the printed form of the constraint (``Bool(str(soft[i]))``), so a
    selector returned by :meth:`grow` prints like the constraint but is an
    atomic Boolean variable, not the constraint itself.  Use :meth:`formula`
    to map a returned selector back to the original constraint.
    """

    def __init__(self, hard, soft):
        """Set up the solver.

        hard: constraint(s) asserted unconditionally on the internal solver.
        soft: list of z3 Boolean formulas whose maximal satisfiable subsets
            are to be enumerated.
        """
        # Per-instance state.  These used to be class attributes, which made
        # every MSSSolver share a single z3 Solver and selector cache: a second
        # instance silently inherited the hard constraints, selector
        # definitions and blocking clauses of the first one.
        self.s = Solver()
        self.varcache = {}  # index (negative = negated) -> selector literal
        self.idcache = {}   # selector variable -> index into `soft`
        self.n = len(soft)
        self.soft = soft
        self.s.add(hard)
        self.soft_vars = set([self.c_var(i) for i in range(self.n)])
        self.orig_soft_vars = set([self.c_var(i) for i in range(self.n)])
        # Tie each selector to its soft constraint: selector <=> soft[i].
        self.s.add([(self.c_var(i) == soft[i]) for i in range(self.n)])

    def c_var(self, i):
        """Return the (cached) selector literal for soft constraint ``abs(i)``.

        Non-negative ``i`` yields the selector variable itself; negative ``i``
        yields its negation.  The selector variable is recorded in ``idcache``
        so it can be mapped back to its index in ``soft``.
        """
        if i not in self.varcache:
            v = Bool(str(self.soft[abs(i)]))
            self.idcache[v] = abs(i)
            if i >= 0:
                self.varcache[i] = v
            else:
                self.varcache[i] = Not(v)
        return self.varcache[i]

    def formula(self, lit):
        """Return the original soft constraint that selector ``lit`` stands for.

        ``lit`` must be one of the selector variables returned by :meth:`grow`
        (equivalently, by ``enumerate_sets``); it is looked up in ``idcache``.
        """
        return self.soft[self.idcache[lit]]

    def update_unknown(self):
        """Fetch the latest model and move selectors true in it into the MSS.

        Selectors in ``self.unknown`` that are true in the model are appended
        to ``self.mss``; the rest become the new ``self.unknown``.
        """
        self.model = self.s.model()
        new_unknown = set([])
        for x in self.unknown:
            if is_true(self.model[x]):
                self.mss.append(x)
            else:
                new_unknown.add(x)
        self.unknown = new_unknown

    def add_def(self, fml):
        """Create a propositional atom named after ``fml``, assert atom == fml, return it."""
        name = Bool("%s" % fml)
        self.s.add(name == fml)
        return name

    def relax_core(self, Fs):
        """Replace the selectors of core ``Fs`` in ``soft_vars`` by weaker definitions.

        The selectors in ``Fs`` are removed from ``soft_vars``, and for each
        ``i`` the definition ``Or(And(Fs[i], ..., Fs[0]), Fs[i+1])`` is added
        instead, so later iterations examine weakened literals.
        """
        assert(Fs <= self.soft_vars)
        prefix = BoolVal(True)
        self.soft_vars -= Fs
        Fs = list(Fs)
        for i in range(len(Fs) - 1):
            prefix = self.add_def(And(Fs[i], prefix))
            self.soft_vars.add(self.add_def(Or(prefix, Fs[i + 1])))

    def resolve_core(self, core):
        """Expand core literals that are 'explained' (implied by other literals).

        A literal with an entry in ``mcs_explain`` is replaced by its
        explanation set; all other literals are kept as-is.
        """
        new_core = set([])
        for x in core:
            if x in self.mcs_explain:
                new_core |= self.mcs_explain[x]
            else:
                new_core.add(x)
        return new_core

    def grow(self):
        """Extract one MSS from the current satisfiable state.

        Grows an MSS from the current model, records cores for literals that
        cannot be added, then strengthens the solver: it asserts that at least
        one constraint outside this MSS must hold next time, and relaxes
        disjoint cores so they are avoided in later iterations.

        Returns the list of original selector variables (see :meth:`formula`)
        that are true in the final model.

        Raises RuntimeError if the solver returns neither sat nor unsat.
        """
        self.mss = []
        self.mcs = []
        self.nmcs = []
        self.mcs_explain = {}
        # update_unknown() immediately rebinds self.unknown to a fresh set,
        # so the pops below never mutate self.soft_vars.
        self.unknown = self.soft_vars
        self.update_unknown()
        cores = []
        while len(self.unknown) > 0:
            x = self.unknown.pop()
            is_sat = self.s.check(self.mss + [x] + self.nmcs)
            if is_sat == sat:
                self.mss.append(x)
                self.update_unknown()
            elif is_sat == unsat:
                core = self.s.unsat_core()
                core = self.resolve_core(core)
                self.mcs_explain[Not(x)] = {y for y in core if not eq(x, y)}
                self.mcs.append(x)
                self.nmcs.append(Not(x))
                cores += [core]
            else:
                # Previously printed and called the site-only exit(); raising
                # lets callers (e.g. generator consumers) handle it.
                raise RuntimeError("solver returned %s" % is_sat)
        mss = [x for x in self.orig_soft_vars if is_true(self.model[x])]
        mcs = [x for x in self.orig_soft_vars if not is_true(self.model[x])]
        # Block this MSS: some constraint outside it must hold from now on.
        self.s.add(Or(mcs))
        core_literals = set([])
        cores.sort(key=len)
        for core in cores:
            # Relax only cores disjoint from those already relaxed.
            if len(core & core_literals) == 0:
                self.relax_core(core)
                core_literals |= core
        return mss

And here is some other code I use:

def all_smt(s, initial_terms):
    """
    Yield models of solver `s`, blocking repeated value combinations of
    `initial_terms`.

    s: a z3 solver (possibly already holding some constraints)
    initial_terms: an iterable of z3 terms whose values are enumerated

    Every push() made on `s` is popped again, also when the caller stops
    iterating early (the generator is closed), so `s` is left as it was.

    From: https://stackoverflow.com/questions/11867611/z3py-checking-all-
    solutions-for-equation/70656700#70656700
    """
    def block_term(s, m, t):
        # Forbid the value that t has in model m.
        s.add(t != m.eval(t, model_completion=True))
    def fix_term(s, m, t):
        # Pin t to the value it has in model m.
        s.add(t == m.eval(t, model_completion=True))
    def all_smt_rec(terms):
        if sat == s.check():
            m = s.model()
            yield m
            for i in range(len(terms)):
                s.push()
                try:
                    block_term(s, m, terms[i])
                    for j in range(i):
                        fix_term(s, m, terms[j])
                    yield from all_smt_rec(terms[i:])
                finally:
                    # Runs on normal completion and on generator close.
                    s.pop()
    yield from all_smt_rec(list(initial_terms))

The output of the main function looks fine; however, I am interested in Boolean problems like the following:

import z3
from functools import reduce  # reduce is not a builtin in Python 3

p, q = z3.Bools('p q')

hard = z3.Or(q, p)
soft = [z3.Or(p, q), z3.Not(z3.And(q, p)), z3.Not(z3.And(p, q))]

solver = MSSSolver(hard, soft)
mms = tuple(enumerate_sets(solver))
# Literals common to every MSS.  NOTE: these are MSSSolver selector variables
# named after the soft constraints, not the constraints themselves.
# Start from an empty set when no MSS was found instead of failing on mms[0].
intersection = reduce(lambda x, y: x & set(y), mms, set(mms[0]) if mms else set())

print(intersection, '\n')
s = z3.Solver()
s.add(intersection)
for i in all_smt(s, [p, q]):
    print(i)

which prints out

{Or(p, q), Not(And(q, p)), Not(And(p, q))} 

[Not(And(p, q)) = True]
[p = True, Not(And(p, q)) = True]
[p = True, q = True, Not(And(p, q)) = True]
[q = True, p = False, Not(And(p, q)) = True]

On the other hand the following:

# Assert the actual formulas (rather than MSS selector literals) and enumerate.
constraints = {z3.Or(p, q), z3.Not(z3.And(q, p)), z3.Not(z3.And(p, q))}
s = z3.Solver()
s.add(constraints)
for model in all_smt(s, [p, q]):
    print(model)

prints out

[q = True, p = False]
[p = True, q = False]

In principle, the two should be equivalent; however, they produce different outputs. In particular, the latter is correct, while the former produces more models than there are actual satisfying assignments. I don't understand the internals of this code well enough to tell where the difference comes from.

Just to confirm that something weird is going on, the following:

p, q = z3.Bools('p q')
s = z3.Solver()
# NOTE: set iteration order decides which element index 1 is.
s.add(list(intersection)[1])
s.add(p)
s.add(q)
# s.check() returns a CheckSatResult, which is always truthy; compare with
# z3.sat explicitly so s.model() is not called on an unsat/unknown result.
if s.check() == z3.sat:
    print(s.model())

prints out:

[Not(And(q, p)) = True, p = True, q = True]

---- addition -----

The following comparison might be helpful:

# Compare the SMT-LIB form of the MSS literals with that of the real formulas.
for exprs in (intersection, {z3.Or(p, q), z3.Not(z3.And(q, p)), z3.Not(z3.And(p, q))}):
    print([e.sexpr() for e in exprs])

which prints:

['|Or(p, q)|', '|Not(And(q, p))|', '|Not(And(p, q))|']
['(not (and p q))', '(not (and q p))', '(or p q)']

Solution

  • In SMT-LIB, a symbol written between vertical bars is a single atomic name. (This allows names to contain arbitrary characters.) To see this, look at what the solver actually contains:

    # `reduce` here is functools.reduce (not a builtin in Python 3).
    mms = tuple(enumerate_sets(solver))
    intersection = reduce(lambda x, y: x & set(y), mms, set(mms[0]))
    s = z3.Solver()
    # Each element of `intersection` is an MSSSolver selector literal.
    s.add(intersection)
    # Print the solver's assertions in SMT-LIB form.
    print(s.sexpr())
    

    This prints:

    (declare-fun |Or(p, q)| () Bool)
    (declare-fun |Not(And(q, p))| () Bool)
    (declare-fun |Not(And(p, q))| () Bool)
    (assert |Or(p, q)|)
    (assert |Not(And(q, p))|)
    (assert |Not(And(p, q))|)
    

    You can see that Boolean variables named |Or(p, q)| (and so on) are being created, which is not what you intended.

    I think the confusion here is that you're treating the results of enumerate_sets as a list of z3 expressions. They are not: they are Boolean literals that stand for the underlying terms. The code simply used each term's printed form as the name of a new Boolean variable, which explains the vertical bars in the output.

    Admittedly, this is very confusing; I guess the original authors didn't intend the results to be used as expressions again. To do what you want, you'll have to keep track of the terms as the MSS is constructed, or "parse" back the resulting names. (MSSSolver already records enough for the first approach: solver.idcache maps each returned literal to its index in soft, so solver.soft[solver.idcache[t]] recovers the original term.) A quick and dirty solution is to simply eval them back:

    solver = MSSSolver(hard, soft)
    mss_list = list(enumerate_sets(solver))
    # Keep only the literals common to *every* MSS (the question's intersection).
    # Concatenating all MSSes would, once there is more than one MSS, assert
    # constraints from different maximal subsets together; two distinct MSSes
    # cannot both hold alongside `hard`, so the result would be unsatisfiable.
    others = [set(m) for m in mss_list[1:]]
    mms_terms = [t for t in mss_list[0] if all(t in o for o in others)] if mss_list else []
    # Quick and dirty: strip the |...| quoting and eval the name back into a z3
    # expression (needs Or/And/Not/p/q in scope). Never eval untrusted text.
    mms = [eval(t.sexpr().replace("|", "")) for t in mms_terms]
    s = z3.Solver()
    s.add(mms)
    print(s.sexpr())

    for i in all_smt(s, [p, q]):
        print(i)
    

    This prints:

    (declare-fun q () Bool)
    (declare-fun p () Bool)
    (assert (or p q))
    (assert (not (and q p)))
    (assert (not (and p q)))
    
    [p = True, q = False]
    [p = False, q = True]
    

    which I think is what you were expecting. (Note that we had to strip the | characters from the names so that they are valid z3py expressions, as discussed above.)