Tags: python, profiler, solver, sat

Slow dnf to cnf in pycosat


Question in short

To have a proper input for pycosat, is there a way to speed up calculation from dnf to cnf, or to circumvent it altogether?

Question in detail

I have been watching this video from Raymond Hettinger about modern solvers. I downloaded the code, and implemented a solver for the game Towers in it. Below I share the code to do so.

Example Tower puzzle (solved):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

The problem I encounter is that the conversion from dnf to cnf takes forever. Let's say that you know there are 3 towers visible from a certain line of sight. This leaves 35 possible permutations of the heights 1-5 for that row.

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

This is a disjunctive normal form: an OR of several AND statements. This needs to be converted into a conjunctive normal form: an AND of several OR statements. This is, however, very slow. On my MacBook Pro, it didn't finish calculating this cnf after 5 minutes for a single row. For the entire puzzle, this has to be done up to 20 times (for a 5x5 grid).
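The figure of 35 candidate rows quoted above can be reproduced by brute force. This is a minimal, self-contained sketch; `count_visible` and `perms_with_clue` are illustrative names, with `count_visible` standing in for the `visible_from_line` helper in the full listing.

```python
from itertools import permutations


def count_visible(line):
    """Count towers visible from the front of `line` (illustrative helper)."""
    visible = 0
    highest_seen = 0
    for height in line:
        if height > highest_seen:
            visible += 1
            highest_seen = height
    return visible


def perms_with_clue(n, clue):
    """All orderings of heights 1..n that show exactly `clue` towers."""
    return [p for p in permutations(range(1, n + 1)) if count_visible(p) == clue]


rows = perms_with_clue(5, 3)
print(len(rows))  # 35
```

Each of these 35 tuples becomes one AND group of the DNF for that row.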

What would be the best way to optimize this code, in order to make the computer able to solve this Towers puzzle?

This code is also available from this GitHub repository.

import string

import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
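            left = self.visible_from_left[index]
            right = self.visible_from_right[index]
            # str(0) is truthy, so test the clue itself to blank out a missing clue
            elems = [str(left) if left else ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(right) if right else ' ']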
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()

The time is actually spent in this helper function, part of the helper code that came along with the video.

def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))

The output from pyinstrument when using a 4x4 grid shows that the line cnf = { ... } in here is the culprit:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:05:58  Samples:  146
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.515     CPU time: 0.506
/   _/                      v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..

Solution

  • First, it's good to note the difference between equivalence and equisatisfiability. In general, converting an arbitrary boolean formula (say, something in DNF) to CNF can result in an exponential blow-up in size.

    This blow-up is the issue with your from_dnf approach: whenever you handle another product term, each of the literals in that product demands a new copy of the current cnf clause set (to which it will add itself in every clause). If you have n product terms of size k, the growth is O(k^n).

    In your case n is actually a function of k!. What's kept as a product term is filtered to those satisfying the view constraint, but overall the runtime of your program is roughly in the region of O(k^f(k!)). Even if f grows logarithmically, this is still O(k^(k lg k)) and not quite ideal!
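    The growth is easy to observe on toy input. The sketch below is not the original from_dnf: it is a deliberately naive distribution with no pruning, and `distribute` and `disjoint_groups` are hypothetical names used only for illustration.

```python
from itertools import product


def distribute(groups):
    """Naively convert an OR-of-ANDs into an AND-of-ORs.

    Picks one literal from every group in every possible way; each pick
    becomes one clause. No simplification is attempted.
    """
    return {frozenset(choice) for choice in product(*groups)}


def disjoint_groups(n, k):
    """n product terms, each with k literals that appear nowhere else."""
    return [[f"t{i}_{j}" for j in range(k)] for i in range(n)]


for n in range(1, 6):
    clauses = distribute(disjoint_groups(n, k=3))
    print(n, len(clauses))  # clause counts 3, 9, 27, 81, 243, i.e. k**n
```

    With shared literals some clauses collapse or get pruned, but the number of candidate clauses visited still grows in the same way.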

    Because you're asking "is this satisfiable?", you don't need an equivalent formula but merely an equisatisfiable one. This is some new formula that is satisfiable if and only if the original is, but which might not be satisfied by the same assignments.

    For example, (a ∨ b) and (a ∨ c) ∧ (¬b) are each obviously satisfiable, so they are equisatisfiable. But setting b true satisfies the first and falsifies the second, so they are not equivalent. Furthermore the first doesn't even have c as a variable, again making it not equivalent to the second.
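    The two formulas above are small enough to check exhaustively. A quick sketch, with `f` and `g` as throwaway names for the two formulas:

```python
from itertools import product


def f(a, b, c):
    return a or b  # (a ∨ b); c is unused on purpose


def g(a, b, c):
    return (a or c) and not b  # (a ∨ c) ∧ (¬b)


assignments = list(product([False, True], repeat=3))

f_satisfiable = any(f(*v) for v in assignments)
g_satisfiable = any(g(*v) for v in assignments)
equivalent = all(f(*v) == g(*v) for v in assignments)

print(f_satisfiable, g_satisfiable, equivalent)  # True True False
```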

    This relaxation is enough to replace this exponential blow-up with a linear-sized translation instead.


    The critical idea is the use of extension variables. These are fresh variables (i.e., not already present in the formula) that allow us to abbreviate expressions, so we don't end up making multiple copies of them in the translation. Since the new variable is not present in the original, we'll no longer have an equivalent formula; but because the variable will be true if and only if the expression is, it will be equisatisfiable.

    If we wanted to use x as an abbreviation of y, we'd state x ≡ y. This is the same as x → y and y → x, which is the same as (¬x ∨ y) ∧ (¬y ∨ x), which is already in CNF.

    Consider the abbreviation for a product term: x ≡ (a ∧ b). This is x → (a ∧ b) and (a ∧ b) → x, which works out to be three clauses: (¬x ∨ a) ∧ (¬x ∨ b) ∧ (¬a ∨ ¬b ∨ x). In general, abbreviating a product term of k literals with x will produce k binary clauses expressing that x implies each of them, and one (k+1)-clause expressing that all together they imply x. This is linear in k.
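    As a sanity check, the three clauses can be generated and compared against the truth table of x ≡ (a ∧ b). This sketch uses signed integers for literals (negative means negated) rather than the string literals used elsewhere in this answer; `define_and` and `satisfied` are hypothetical helper names.

```python
from itertools import product


def define_and(x, literals):
    """Clauses for x ≡ (l1 ∧ l2 ∧ ...), using signed-integer literals."""
    clauses = [(-x, lit) for lit in literals]               # x → each literal
    clauses.append(tuple(-lit for lit in literals) + (x,))  # all literals → x
    return clauses


def satisfied(clauses, assignment):
    """True if every clause has a literal made true by `assignment` (var -> bool)."""
    return all(
        any(assignment[abs(lit)] == (lit > 0) for lit in clause)
        for clause in clauses
    )


# Variables: 1 = a, 2 = b, 3 = x
clauses = define_and(3, [1, 2])
print(clauses)  # [(-3, 1), (-3, 2), (-1, -2, 3)]

# The clauses hold exactly when x has the same value as (a ∧ b).
for a, b, x in product([False, True], repeat=3):
    assert satisfied(clauses, {1: a, 2: b, 3: x}) == (x == (a and b))
```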

    To really see why this helps, try converting (a ∧ b ∧ c) ∨ (d ∧ e ∧ f) ∨ (g ∧ h ∧ i) to an equivalent CNF with and without an extension variable for the first product term. Of course, we won't just stop with one term: if we abbreviate each term then the result is precisely a single CNF clause: (x ∨ y ∨ z) where these each abbreviate a single product term. This is a lot smaller!

    This approach can be used to turn any circuit into an equisatisfiable formula, linear in size and in CNF. This is called a Tseitin transformation. Your DNF formula is simply a circuit composed of a bunch of arbitrary fan-in AND gates, all feeding into a single arbitrary fan-in OR gate.

    Best of all, although this formula is not equivalent due to additional variables, we can recover an assignment for the original formula by simply dropping the extension variables. It is sort of a 'best case' equisatisfiable formula, being a strict superset of the original.
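    Both claims (the smaller size, and recovering a solution by dropping the extension variables) can be checked by brute force on the nine-variable example. This is only a sketch: literals are (variable, polarity) pairs, and the names x, y, z for the extension variables are arbitrary.

```python
from itertools import product

TERMS = [("a", "b", "c"), ("d", "e", "f"), ("g", "h", "i")]
EXT = ["x", "y", "z"]  # one extension variable per term


def tseitin(terms, ext):
    """Clauses as lists of (variable, polarity) pairs."""
    cnf = []
    for term, x in zip(terms, ext):
        for lit in term:
            cnf.append([(x, False), (lit, True)])                 # x → lit
        cnf.append([(lit, False) for lit in term] + [(x, True)])  # term → x
    cnf.append([(x, True) for x in ext])                          # some term holds
    return cnf


def holds(cnf, model):
    return all(any(model[var] == pol for var, pol in clause) for clause in cnf)


cnf = tseitin(TERMS, EXT)
originals = [lit for term in TERMS for lit in term]

naive_size = len(list(product(*TERMS)))
print(naive_size, len(cnf))  # 27 13

# An assignment to a..i satisfies the DNF exactly when it can be extended
# with values for x, y, z so that the CNF holds.
for values in product([False, True], repeat=len(originals)):
    base = dict(zip(originals, values))
    dnf_true = any(all(base[lit] for lit in term) for term in TERMS)
    extendable = any(
        holds(cnf, {**base, **dict(zip(EXT, ext_values))})
        for ext_values in product([False, True], repeat=len(EXT))
    )
    assert dnf_true == extendable
```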


    To patch this into your code, I added:

    # Uses pseudo-namespacing to avoid collisions.
    _EXT_SUFFIX = "___"
    _NEXT_EXT_INDEX = 0
    
    
    def is_ext_var(element) -> bool:
        return element.endswith(_EXT_SUFFIX)
    
    
    def ext_var() -> str:
        global _NEXT_EXT_INDEX
        ext_index = _NEXT_EXT_INDEX
        _NEXT_EXT_INDEX += 1
    
        return intern(f"{ext_index}{_EXT_SUFFIX}")
    

    This lets us pull a fresh named variable out of thin air. Since these extension variable names don't have meaningful semantics to your solution display function, I changed:

    point_to_value = {
        point: value for point, value in [fact.split() for fact in self.solution]
    }
    

    into:

    point_to_value = {
        point: value
        for point, value in [
            fact.split() for fact in self.solution if not is_ext_var(fact)
        ]
    }
    

    There are certainly better ways to do this; this is just a patch. :)

    Reimplementing your from_dnf with the above ideas, we get:

    def from_dnf(groups) -> "cnf":
        "Convert from or-of-ands to and-of-ors, equisatisfiably"
        cnf = []
    
        extension_vars = []
        for group in groups:
            extension_var = ext_var()
            neg_extension_var = neg(extension_var)
    
            imply_ext_clause = []
            for literal in group:
                imply_ext_clause.append(neg(literal))
                cnf.append((neg_extension_var, literal))
    
            imply_ext_clause.append(extension_var)
            cnf.append(tuple(imply_ext_clause))
    
            extension_vars.append(extension_var)
    
        cnf.append(tuple(extension_vars))
        return cnf
    

    Each group gets an extension variable. Each literal in the group adds its negation into the (k+1)-sized implication clause, and becomes implied by the extension. After the literals are handled, the extension variable finalizes the remaining implication and adds itself to the list of new extension variables. Finally, at least one of these extension variables must be true.

    This change alone lets me solve this 5x5 puzzle ~instantly:

    self.visible_from_top = [3, 2, 1, 4, 2]
    self.visible_from_bottom = [2, 2, 4, 1, 2]
    self.visible_from_left = [3, 2, 3, 1, 3]
    self.visible_from_right = [2, 2, 1, 3, 2]
    self.given_numbers = {}
    

    I added some timing output as well (this needs import time at the top of the module):

    @property
    def solution(self):
        if self._solution is None:
            start_time = time.perf_counter()
    
            cnf = self.cnf
            cnf_time = time.perf_counter()
            print(f"CNF: {cnf_time - start_time}s")
    
            self._solution = solve_one(cnf)
            end_time = time.perf_counter()
            print(f"Solve: {end_time - cnf_time}s")
        return self._solution
    

    The 5x5 puzzle gives me:

    CNF: 0.00565183162689209s
    Solve: 0.005589433014392853s
    

    However, we still have that pesky k! growth when enumerating viable tower height permutations.

    I generated a 9x9 puzzle (the largest the site permits), which corresponds to:

    self.visible_from_top = [3, 3, 3, 3, 1, 4, 2, 4, 2]
    self.visible_from_bottom = [3, 1, 4, 2, 5, 3, 3, 2, 3]
    self.visible_from_left = [3, 3, 1, 2, 4, 5, 2, 3, 2]
    self.visible_from_right = [3, 1, 7, 4, 3, 3, 2, 2, 4]
    self.given_numbers = {
        "AB": 5,
        "AD": 4,
        "BD": 3,
        "BE": 2,
        "CD": 7,
        "CF": 5,
        "CG": 1,
        "DB": 1,
        "DH": 7,
        "EA": 4,
        "EI": 2,
        "FA": 2,
        "FE": 8,
        "GG": 7,
        "GI": 6,
        "HA": 3,
        "HF": 2,
        "HH": 1,
        "IG": 6,
    }
    

    This gives me:

    CNF: 28.505195066332817s
    Solve: 40.48229135945439s
    

    Ideally most of the time would go into solving rather than generating, yet generating the CNF still takes about 40% of the total.
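    To get a feel for the generation cost: every clue walks all n! permutations and keeps only those showing the right number of towers. The sketch below counts the survivors without enumerating them; `perms_showing` is a hypothetical helper, and the recurrence comes from considering where the shortest tower is placed.

```python
from math import factorial


def perms_showing(n, k):
    """Number of orderings of n distinct heights with exactly k visible.

    Place the shortest tower last: in the first position it is visible and
    hides nothing; in any of the other n - 1 positions it is hidden and
    changes nothing for the remaining towers.
    """
    if n == 0:
        return 1 if k == 0 else 0
    if k == 0:
        return 0
    return perms_showing(n - 1, k - 1) + (n - 1) * perms_showing(n - 1, k)


for n in range(4, 10):
    print(n, factorial(n), [perms_showing(n, k) for k in range(1, n + 1)])
```

    For n = 5 and k = 3 this reproduces the 35 groups from the question. For n = 9 each clue filters 362880 permutations, and there are up to 36 clues.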

    In my opinion, using DNF in a CNF-SAT translation is often† a sign of the wrong approach. Solvers are way better at exploring and learning about the solution space than we are; spending a factorial amount of time pre-exploring is actually worse than the solver's exponential worst case.

    It's understandable to 'fall back' to DNF, because programmers naturally think in terms of "write an algorithm that emits solutions". But the real benefit of solvers kicks in when you encode this in the problem. Let the solver reason about conditions in which solutions become infeasible. To do this, we want to think in terms of circuits. Lucky for us, we also know how to turn a circuit into CNF quickly.

    †I said "often"; if your DNF is small and quick to produce (like a single circuit gate), or if encoding it to a circuit is prohibitively complicated, then it can be helpful to pre-compute some of the solution space.


    You've actually already done some of this! For example, we will need a circuit that counts how many times a certain number appears in a span (row or column), and an assertion that this number is exactly one. Then for each span and for each number, we'll emit this circuit. That way if a tower of size e.g. 3 appears twice in a row, the counter for that row for 3 will emit '2' and our assertion that it be '1' will not be upheld.

    Your one_of constraint is one possible implementation of this. Yours uses the 'obvious' pairwise encoding: for each location in the span, if N is present at that location then it is not present in any other location. This is actually quite a good encoding because it's comprised almost entirely of binary clauses, and SAT solvers love binary clauses (they use significantly less memory and propagate often). But for really large sets of things to count, this O(n^2) scaling can become an issue.
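    For reference, here is what a pairwise encoding looks like when written out. This is a sketch and not the one_of from the helper code: it uses signed integers for literals, and `exactly_one_pairwise` and `satisfied` are hypothetical names.

```python
from itertools import combinations, product


def exactly_one_pairwise(variables):
    """Pairwise 'exactly one' over positive integer variables."""
    clauses = [tuple(variables)]                                  # at least one
    clauses += [(-a, -b) for a, b in combinations(variables, 2)]  # at most one
    return clauses


def satisfied(clauses, assignment):
    """True if every clause has a literal made true by `assignment` (var -> bool)."""
    return all(
        any(assignment[abs(lit)] == (lit > 0) for lit in clause)
        for clause in clauses
    )


variables = [1, 2, 3, 4]
clauses = exactly_one_pairwise(variables)
print(len(clauses))  # 7: one long clause plus C(4, 2) = 6 binary clauses

# The encoding accepts exactly the assignments with a single true variable.
for values in product([False, True], repeat=len(variables)):
    model = dict(zip(variables, values))
    assert satisfied(clauses, model) == (sum(values) == 1)
```

    The number of binary clauses is n(n-1)/2, which is where the O(n^2) scaling comes from.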

    You can imagine an alternative approach where you literally encode an adder circuit: each location is an input bit to the circuit, and the circuit produces n bits of output telling you the final sum (the paper above is a good read!). You then assert this sum is exactly one using unit clauses that force specific output bits.

    It may seem redundant to encode a circuit only to force some of its outputs to be a constant value. However, this is much easier to reason about and modern solvers are aware that encodings do this and optimize for it. They perform significantly more sophisticated in-processing than the initial encoding process could reasonably do. The 'art' of using solvers is in knowing and testing when these alternative encodings work better than others.

    Note that exactly_k_of is at_least_k_of along with at_most_k_of. You've noted this in your Q class == implementation. Implementing at_least_1_of is trivial, being one clause; at_most_1_of is so common it's often just called AMO. I encourage you to try implementing < and > in some of the other ways discussed in the paper (perhaps even choosing which to use based on input size) to get a feel for it.


    Turning our attention back to the k! visibility constraints, what we need is a circuit that tells us how many towers are visible from a certain direction, which we can then assert be a specific value.

    Stop and think about how this could be done; it's not easy!

    Analogous to the various one_of approaches, we can go with a 'pure' circuit for counting or use a simpler but worse-scaling pairwise approach. I have attached the sketch of the pure circuit approach at the very bottom (‡) of this answer. For now we will use the pairwise method.

    The main observation to make is that among non-visible towers, we don't care about their permutations. Consider:

    3 -> 1 5 _ _ _ 9 _ _ _
         A B C D E F G H I
    

    We see 3 towers from the left as long as the CDE group contains 2, 3, and 4, and likewise if the GHI group contains 6, 7, and 8. But the order in which they appear in the group has no implication on the visible towers.
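    This observation is cheap to confirm by enumeration. A small sketch, with `count_visible` as an illustrative helper that counts visible towers from the left:

```python
from itertools import permutations


def count_visible(line):
    """Count towers visible from the front of `line` (illustrative helper)."""
    visible = 0
    highest_seen = 0
    for height in line:
        if height > highest_seen:
            visible += 1
            highest_seen = height
    return visible


counts = set()
for cde in permutations((2, 3, 4)):
    for ghi in permutations((6, 7, 8)):
        line = (1, 5) + cde + (9,) + ghi
        counts.add(count_visible(line))

print(counts)  # {3}
```

    All 36 arrangements of the hidden towers give the same count, yet the DNF approach spends one product term on each of them.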

    Rather than compute which towers are visible, we will declare which towers are visible and follow their implications. We'll be filling in this function:

    def visible_in_span(points: Collection[str], desired: int) -> "cnf":
        """Assert desired visible towers in span. Wlog, visibility is from index 0."""
        points = list(points)
        n = len(points)
        assert desired <= n
    
        cnf = []
    
        # ...
    
        return cnf
    

    Assume a fixed span and viewing direction: each location will have k associated variables, Av1 through Avk stating "this is the kth visible tower". We will also have Av ≡ (Av1 ∨ Av2 ∨ ⋯ ∨ Avk) meaning "A has a visible tower".

    In the above example, Av1, Bv2, and Fv3 are all true. There are some obvious implications to emit. At a location, at most one of these is true (you can't be both the first and second visible tower) — but not exactly one, since it's perfectly fine to have a non-visible tower. Another is that if a location is the kth visible tower, then no other location is also the kth visible tower.

    We can add this so far (it needs import collections, plus the Q class from the helper code):

    is_kth_visible_tower_at = {}
    is_kth_visible_tower_vars = collections.defaultdict(list)
    is_visible_tower_at = {}
    for point in points:
        is_visible_tower_vars = []
        for k in range(1, n + 1):
            # Xvk
            is_kth_visible_tower_var = ext_var()
    
            is_kth_visible_tower_at[(point, k)] = is_kth_visible_tower_var
            is_kth_visible_tower_vars[k].append(is_kth_visible_tower_var)
            is_visible_tower_vars.append(is_kth_visible_tower_var)
    
        # Xv
        is_visible_tower_at_var = ext_var()
        # Xv → (Xv1 ∨ Xv2 ∨ ⋯)
        cnf.append(tuple([neg(is_visible_tower_at_var)] + is_visible_tower_vars))
        # (Xv1 ∨ Xv2 ∨ ⋯) → Xv
        for is_visible_tower_var in is_visible_tower_vars:
            cnf.append((neg(is_visible_tower_var), is_visible_tower_at_var))
    
        is_visible_tower_at[point] = is_visible_tower_at_var
    
        # At most one visible tower here.
        cnf += Q(is_visible_tower_vars) <= 1
    
    # At most one kth visible tower anywhere.
    for k in range(1, n + 1):
        cnf += Q(is_kth_visible_tower_vars[k]) <= 1
    

    Next we need ordering among visible towers, so that the (k+1)th visible tower comes after the kth visible tower. This is accomplished by the (k+1)th visible tower forcing at least one of the prior locations to be the kth visible tower. E.g., Dv3 → (Av2 ∨ Bv2 ∨ Cv2) and Cv2 → (Av1 ∨ Bv1). We know Av1 is always true, which provides the base case. (If we enter a situation like needing B to be the third visible tower, that will require A to be the second visible tower, which contradicts Av1.)

    # Towers are ordered.
    for index, point in enumerate(points):
        if index == 0:
            cnf += basic_fact(is_kth_visible_tower_at[(point, 1)])
            continue
    
        for k in range(1, n + 1):
            # Xvk → ⋯
            implication = [neg(is_kth_visible_tower_at[(point, k)])]
    
            j = k - 1
            if j > 0:
                for index_j, point_j in enumerate(points):
                    if index_j == index:
                        break
    
                    # ⋯ ∨ Wvj ∨ ⋯
                    implication.append(is_kth_visible_tower_at[(point_j, j)])
    
            cnf.append(tuple(implication))
    

    So far so good, but we haven't related tower height to visibility. The above would allow 9 8 7 as a solution, calling 9 the first visible tower, 8 the second, and 7 the third. To solve this we want a tower placement to prohibit a smaller tower from also being visible.

    Each location will again receive a set of abbreviations indicating whether a tower of a given height (or less) is obscured there, called Ao1, Ao2, and so on. This will give us a 'grid' of implications that keeps things simpler. The first is that a taller tower being obscured implies the next-shorter tower at the same location is also obscured, so that Ao3 → Ao2 and Ao2 → Ao1. The second is that if a tower is obscured at one location, it is also obscured at all later locations. This is Ao3 → Bo3 and Bo3 → Co3 and so on.

    is_height_obscured_at = {}
    is_height_obscured_previous = [None] * n
    for point in points:
        is_obscured_previous = None
        for k in range(1, n + 1):
            # Xok
            is_height_obscured_var = ext_var()
    
            # Wok → Xok
            is_k_obscured_previous = is_height_obscured_previous[k - 1]
            if is_k_obscured_previous is not None:
                cnf.append((neg(is_k_obscured_previous), is_height_obscured_var))
    
            # Xok → Xo(k-1)
            if is_obscured_previous is not None:
                cnf.append((neg(is_height_obscured_var), is_obscured_previous))
    
            is_height_obscured_at[(point, k)] = is_height_obscured_var
            is_height_obscured_previous[k - 1] = is_height_obscured_var
            is_obscured_previous = is_height_obscured_var
    

    From this it's easy to see that stating e.g. Bo4 implies that towers of height 4 or less are obscured at B and at every later location. We can now easily relate tower placement to obscurity: A5 → Bo4.

    # A placed tower obscures smaller later towers.
    for index, point in enumerate(points):
        if index + 1 == len(points):
            break
    
        next_point = points[index + 1]
        for k in range(2, n + 1):
            j = k - 1
    
            # Xk → Yo(k-1)
            cnf.append((neg(comb(point, k)), is_height_obscured_at[(next_point, j)]))
    

    Last, we need to relate obscurity to visibility. We'll need one final set of abbreviations, stating that a specific tower height is visible at a location. At the risk of making typos easy, we'll call this Ahv for some height h, so that Ahv ≡ (Ah ∧ Av). A concrete example would be C3v ≡ (C3 ∧ Cv): a tower of height 3 is visible at C if and only if there is a tower visible at C, and that tower is the height 3 tower.

    is_height_visible_at = {}
    for point in points:
        for k in range(1, n + 1):
            # Xhv
            height_visible_at_var = ext_var()
    
            # Xhv ≡ (Xh ∧ Xv)
            cnf.append((neg(height_visible_at_var), comb(point, k)))
            cnf.append((neg(height_visible_at_var), is_visible_tower_at[point]))
            cnf.append(
                (
                    neg(comb(point, k)),
                    neg(is_visible_tower_at[point]),
                    height_visible_at_var,
                )
            )
    
            is_height_visible_at[(point, k)] = height_visible_at_var
    

    This allows us to emit the final implications on tower placement. If a tower of height h is obscured, it is not visible: Bo4 → ¬B4v. This is not an equivalence, and we cannot treat Bo4 ≡ ¬B4v; maybe ¬B4v holds because B4 simply isn't placed there (but would be visible if it were!).

    for point in points:
        for k in range(1, n + 1):
            # Xok → ¬Xkv
            cnf.append(
                (
                    neg(is_height_obscured_at[(point, k)]),
                    neg(is_height_visible_at[(point, k)]),
                )
            )
    

    To relate this to the puzzle-specific visibility value, we just need to prohibit too many visible towers and ensure the desired count is visible at least once (and therefore exactly once):

    # At least one of the towers is the desired kth visible.
    cnf.append(tuple(is_kth_visible_tower_vars[desired]))
    
    # None of the towers can be visible above the desired kth.
    if desired < n:
        for is_kth_visible_tower_var in is_kth_visible_tower_vars[desired + 1]:
            cnf += basic_fact(neg(is_kth_visible_tower_var))
    
    return cnf
    

    We only need to block the first level of undesirable kth visible towers. Since each (k+1)th visible tower implies the existence of a kth visible tower, the higher levels are ruled out too. (And so on.)

    Finally, we hook this into the CNF builder:

    # Set visible from left
    if self.visible_from_left:
        for index, row in enumerate(self.rows):
            target_visible = self.visible_from_left[index]
            if not target_visible:
                continue
    
            cnf += visible_in_span(row, target_visible)
    
    # Set visible from right
    if self.visible_from_right:
        for index, row in enumerate(self.rows):
            target_visible = self.visible_from_right[index]
            if not target_visible:
                continue
    
            cnf += visible_in_span(reversed(row), target_visible)
    
    # Set visible from top
    if self.visible_from_top:
        for index, col in enumerate(self.cols):
            target_visible = self.visible_from_top[index]
            if not target_visible:
                continue
    
            cnf += visible_in_span(col, target_visible)
    
    # Set visible from bottom
    if self.visible_from_bottom:
        for index, col in enumerate(self.cols):
            target_visible = self.visible_from_bottom[index]
            if not target_visible:
                continue
    
            cnf += visible_in_span(reversed(col), target_visible)
    

    The above gives me the 9x9 solution much more quickly:

    CNF: 0.028973951935768127s
    Solve: 0.07169117406010628s
    

    About 685x faster, and the solver is doing more of the overall work. Not bad for quick and dirty!

    There are many ways to clean this up. E.g., every place we see cnf.append((neg(a), b)) could be cnf += implies(a, b) instead for readability. We could avoid allocation of pointlessly large kth visible variables, and so on.
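    One possible shape for such an implies helper is sketched below. Both functions are assumptions for illustration: in particular, `neg` here assumes that a leading '~' marks a negated literal, which may not match the helper module exactly.

```python
def neg(literal: str) -> str:
    """Negate a string literal, assuming a leading '~' marks negation."""
    return literal[1:] if literal.startswith("~") else "~" + literal


def implies(a: str, b: str) -> list:
    """CNF for a → b: the single clause (¬a ∨ b)."""
    return [(neg(a), b)]


cnf = []
cnf += implies("AA 5", "Bo4")  # placeholder literal names
print(cnf)  # [('~AA 5', 'Bo4')]
```

    Returning a list of clauses keeps the call sites uniform with the other helpers that are combined via cnf += ....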

    This is not well-tested; I may have missed some implications or a rule. Hopefully it's easy to fix at this point.


    The last thing I want to touch on is the applicability of SAT. Perhaps painfully clear now, SAT solvers are not exactly great at counting and arithmetic. You have to lower to a circuit, hiding the higher-level semantics from the solving process.

    Other approaches will let you express arithmetic, intervals, sets, and so on naturally. Answer set programming (ASP) is one example of this, SMT solvers are another. For small problems SAT is fine, but for difficult problems these higher-level approaches can greatly simplify the problem.

    Each of those may actually internally decide to reason via SAT-solving (SMT in particular), but they will be doing so in the context of some higher-level understanding of the problem.


    ‡ This is the pure circuit approach to counting towers.

    Whether or not it is better than pairwise will depend on the number of towers being counted; maybe the constant factors are so high it's never useful, or maybe it's quite useful even at low sizes. I honestly have no idea — I've encoded huge circuits before and had them work great. It requires experimentation to know.

    I'm going to call Ah the integer height of the tower in location A. That is, rather than a one-hot encoding of either A1 or A2 or … or A9, we'll have Ah0, Ah1, …, and Ah(n-1) as the low through high bits of an n-bit integer (collectively Ah). For a limit of 9x9, 4 bits suffice. We'll also have Bh, Ch, and so on.

    You can join the two representations using A1 ≡ (¬Ah3 ∧ ¬Ah2 ∧ ¬Ah1 ∧ Ah0) and A2 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ ¬Ah0) and A3 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ Ah0) and so on. We have Ah = 3 if and only if A3 is set. (We don't need to add constraints that only one value of Ah is possible at a time, since the one-hot variables associated to each do this already.)
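    The channeling clauses between the two representations can be generated mechanically. A sketch with signed-integer literals; `bit_pattern`, `channel`, and the variable numbering are all hypothetical.

```python
def bit_pattern(value, width):
    """Bits of `value`, lowest bit first, so index i lines up with Ah_i."""
    return [bool((value >> i) & 1) for i in range(width)]


def channel(one_hot, bits, value):
    """Clauses for one_hot ≡ (bits spell out `value`), signed-integer literals."""
    wanted = [
        bit if is_set else -bit
        for bit, is_set in zip(bits, bit_pattern(value, len(bits)))
    ]
    clauses = [(-one_hot, lit) for lit in wanted]               # one_hot → each bit literal
    clauses.append(tuple(-lit for lit in wanted) + (one_hot,))  # bit pattern → one_hot
    return clauses


# Hypothetical numbering: 1..4 are Ah0..Ah3, 5 is the one-hot variable A3.
print(channel(5, [1, 2, 3, 4], value=3))
# [(-5, 1), (-5, 2), (-5, -3), (-5, -4), (-1, -2, 3, 4, 5)]
```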

    With an integer in hand, it might be easier to see how to compute visibility. We can associate each location with a maximum seen tower height, named Am, Bm, and so on; obviously the first tower is always visible and the highest seen, so Am ≡ Ah. Again, this is actually an n-bit value, Am0 through Am(n-1).

    A tower is visible if and only if its value is larger than the prior's highest seen. We'll track visibility with Av, Bv, and so on. This can be done with a digital comparator, so that Bv ≡ (Bh > Am). (Av is a base case, and is simply always true.)

    This lets us fill in the rest of the max values as well. Bm ≡ Bv ? Bh : Am, and so on. A conditional/if-then-else/ite is a digital multiplexer. For a simple 2-to-1 this is straightforward: Bv ? Bh : Am is (Bv ∧ Bh) ∨ (¬Bv ∧ Am), which is really Bmi ≡ (Bv ∧ Bhi) ∨ (¬Bv ∧ Ami) for each bit i ∈ 0..n-1.
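    One common way to put a single multiplexer bit into CNF uses four ternary clauses. The sketch below checks them against the truth table; `mux_clauses`, `satisfied`, and the variable numbering are illustrative assumptions, with signed integers as literals.

```python
from itertools import product


def mux_clauses(out, sel, then, other):
    """Clauses for out ≡ (sel ? then : other), signed-integer literals."""
    return [
        (-sel, -then, out),   # sel ∧ then → out
        (-sel, then, -out),   # sel ∧ ¬then → ¬out
        (sel, -other, out),   # ¬sel ∧ other → out
        (sel, other, -out),   # ¬sel ∧ ¬other → ¬out
    ]


def satisfied(clauses, assignment):
    """True if every clause has a literal made true by `assignment` (var -> bool)."""
    return all(
        any(assignment[abs(lit)] == (lit > 0) for lit in clause)
        for clause in clauses
    )


# Hypothetical numbering: 1 = Bv, 2 = Bh_i, 3 = Am_i, 4 = Bm_i
clauses = mux_clauses(out=4, sel=1, then=2, other=3)

# The clauses hold exactly when the output bit follows the selected input.
for bv, bh, am, bm in product([False, True], repeat=4):
    expected = bh if bv else am
    assert satisfied(clauses, {1: bv, 2: bh, 3: am, 4: bm}) == (bm == expected)
```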

    Then, we'll have a bunch of single inputs Av through Iv that feed into an adder circuit, telling us how many of these inputs are true (i.e., how many towers are visible). This will be yet another n-bit value; then we use unit clauses to assert that it is exactly e.g. 3, if the particular puzzle demands 3 visible towers.
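    Before encoding any of this into clauses, it can help to have a reference model of the circuit on concrete values. The sketch below evaluates the comparator, multiplexer, and final sum with plain Python integers; it is a simulation for testing an encoding against, not a CNF generator, and `visible_by_circuit` is a hypothetical name.

```python
def visible_by_circuit(heights):
    """Evaluate the visibility circuit on concrete tower heights."""
    max_seen = heights[0]                # Am ≡ Ah
    visible = [True]                     # Av is the base case
    for h in heights[1:]:
        v = h > max_seen                 # comparator: Bv ≡ Bh > Am
        max_seen = h if v else max_seen  # multiplexer: Bm ≡ Bv ? Bh : Am
        visible.append(v)
    return sum(visible)                  # adder over Av, Bv, ...


print(visible_by_circuit([2, 1, 3, 4]))  # 3
print(visible_by_circuit([4, 3, 1, 2]))  # 1
```

    The two calls correspond to the first row of the solved 4x4 example in the question, viewed from the left and from the right.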

    We generate this same circuit for every span in every direction. This will be some polynomial-sized encoding of the rules, adding many extension variables and many clauses. A solver can learn a certain tower placement isn't viable not because we said so, but because it implies some unacceptable intermediate visibility. "There should be 4 visible, and 2 are already visible, so that leaves me with...".