Question in short
To produce proper input for pycosat, is there a way to speed up the conversion from DNF to CNF, or to circumvent it altogether?
Question in detail
I have been watching this video from Raymond Hettinger about modern solvers. I downloaded the code, and implemented a solver for the game Towers in it. Below I share the code to do so.
Example Tower puzzle (solved):
    3 3 2 1
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2
The problem I encounter is that the conversion from DNF to CNF takes forever. Let's say that you know there are 3 towers visible from a certain line of sight. In a 5x5 grid, this leaves 35 possible permutations of the values 1-5 in that row.
[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
...
('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]
This is a disjunctive normal form: an OR of several AND statements. It needs to be converted into a conjunctive normal form: an AND of several OR statements. This is, however, very slow. On my MacBook Pro, computing this CNF for a single row had not finished after 5 minutes. For the entire puzzle, this has to be done up to 20 times (for a 5x5 grid).
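The figure of 35 rows can be double-checked by brute force. This is only a small sketch; visible_count and count_rows are names made up for the illustration, and the counting logic mirrors visible_from_line from the code further down.

```python
from itertools import permutations


def visible_count(line):
    """Count towers visible from the start of the line (left-to-right maxima)."""
    visible = 0
    highest_seen = 0
    for height in line:
        if height > highest_seen:
            visible += 1
            highest_seen = height
    return visible


def count_rows(n, target_visible):
    """Number of permutations of 1..n that show exactly target_visible towers."""
    return sum(
        1
        for perm in permutations(range(1, n + 1))
        if visible_count(perm) == target_visible
    )


print(count_rows(5, 3))  # 35
```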
What would be the best way to optimize this code, in order to make the computer able to solve this Towers puzzle?
This code is also available from this Github repository.
import string
import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """
    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """
        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """
        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """
        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            elems = [str(self.visible_from_left[index]) or ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index]) or ' ']
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()
Most of the time is spent in this helper function, which comes from the helper code that accompanies the video.
def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))
The output from pyinstrument when using a 4x4 grid shows that the line cnf = { ... } in here is the culprit:
_ ._ __/__ _ _ _ _ _/_ Recorded: 21:05:58 Samples: 146
/_//_/// /_\ / //_// / //_'/ // Duration: 0.515 CPU time: 0.506
/ _/ v3.4.2
Program: ./src/towers.py
0.515 <module> ../<string>:1
[7 frames hidden] .., runpy
0.513 _run_code runpy.py:62
└─ 0.513 <module> towers.py:1
├─ 0.501 display_solution towers.py:64
│ └─ 0.501 solution towers.py:188
│ ├─ 0.408 cnf towers.py:101
│ │ ├─ 0.397 from_dnf sat_utils.py:65
│ │ │ ├─ 0.329 <setcomp> sat_utils.py:73
│ │ │ ├─ 0.029 [self]
│ │ │ ├─ 0.021 min ../<built-in>:0
│ │ │ │ [2 frames hidden] ..
│ │ │ └─ 0.016 <setcomp> sat_utils.py:79
│ │ └─ 0.009 [self]
│ └─ 0.093 solve_one sat_utils.py:53
│ └─ 0.091 itersolve sat_utils.py:43
│ ├─ 0.064 translate sat_utils.py:32
│ │ ├─ 0.049 <listcomp> sat_utils.py:39
│ │ │ ├─ 0.028 [self]
│ │ │ └─ 0.021 <listcomp> sat_utils.py:39
│ │ └─ 0.015 make_translate sat_utils.py:12
│ └─ 0.024 itersolve ../<built-in>:0
│ [2 frames hidden] ..
└─ 0.009 <module> typing.py:1
[26 frames hidden] typing, abc, ..
First, it's good to note the difference between equivalence and equisatisfiability. In general, converting an arbitrary boolean formula (say, something in DNF) to CNF can result in an exponential blow-up in size.
This blow-up is the issue with your from_dnf approach: whenever you handle another product term, each of the literals in that product demands a new copy of the current cnf clause set (to which it will add itself in every clause). If you have n product terms of size k, the growth is O(k^n).
In your case n is actually a function of k!. What's kept as a product term is filtered to those satisfying the view constraint, but overall the runtime of your program is roughly in the region of O(k^f(k!)). Even if f grows logarithmically, this is still O(k^(k lg k)) and not quite ideal!
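To get a feel for the k^n growth, here is a rough sketch that distributes OR over AND with no simplification at all. It is not your from_dnf (which also prunes tautologies and subsumed clauses); naive_dnf_to_cnf and disjoint_terms are hypothetical helpers, and the terms deliberately share no variables so that nothing can be pruned.

```python
from itertools import product


def naive_dnf_to_cnf(groups):
    """Distribute OR over AND: one clause per way of choosing a single
    literal from every product term."""
    return {frozenset(choice) for choice in product(*groups)}


def disjoint_terms(n, k):
    """Build n product terms of k literals each, sharing no variables."""
    return [[f"t{i}_{j}" for j in range(k)] for i in range(n)]


for n in range(1, 6):
    clauses = naive_dnf_to_cnf(disjoint_terms(n, k=3))
    print(n, len(clauses))  # 3, 9, 27, 81, 243: that is 3**n
```

With 35 terms of 5 literals each, the same construction would head towards 5**35 clauses before any pruning, which is why the conversion never finishes.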
Because you're asking "is this satisfiable?", you don't need an equivalent formula but merely an equisatisfiable one. This is some new formula that is satisfiable if and only if the original is, but which might not be satisfied by the same assignments.
For example, (a ∨ b) and (a ∨ c) ∧ (¬b) are each obviously satisfiable, so they are equisatisfiable. But setting b true satisfies the first and falsifies the second, so they are not equivalent. Furthermore the first doesn't even have c as a variable, again making it not equivalent to the second.
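If you want to convince yourself, the two formulas are small enough to check by truth table. A minimal sketch (first and second are just illustrative names for the two formulas above):

```python
from itertools import product


def first(a, b, c):
    return a or b


def second(a, b, c):
    return (a or c) and (not b)


assignments = list(product([False, True], repeat=3))
models_first = {values for values in assignments if first(*values)}
models_second = {values for values in assignments if second(*values)}

# Equisatisfiable: both have a model, or neither does.
equisatisfiable = bool(models_first) == bool(models_second)
# Equivalent: exactly the same assignments satisfy both.
equivalent = models_first == models_second
print(equisatisfiable, equivalent)  # True False
```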
This relaxation is enough to replace this exponential blow-up with a linear-sized translation instead.
The critical idea is the use of extension variables. These are fresh variables (i.e., not already present in the formula) that allow us to abbreviate expressions, so we don't end up making multiple copies of them in the translation. Since the new variable is not present in the original, we'll no longer have an equivalent formula; but because the variable will be true if and only if the expression is, it will be equisatisfiable.
If we wanted to use x as an abbreviation of y, we'd state x ≡ y. This is the same as x → y and y → x, which is the same as (¬x ∨ y) ∧ (¬y ∨ x), which is already in CNF.
Consider the abbreviation for a product term: x ≡ (a ∧ b). This is x → (a ∧ b) and (a ∧ b) → x, which works out to be three clauses: (¬x ∨ a) ∧ (¬x ∨ b) ∧ (¬a ∨ ¬b ∨ x). In general, abbreviating a product term of k literals with x will produce k binary clauses expressing that x implies each of them, and one (k+1)-clause expressing that all together they imply x. This is linear in k.
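The abbreviation of a product term can be sketched as a small function. This version uses integer literals (a negative number is a negated variable), which is also the convention pycosat takes as input; and_gate_clauses and satisfies are names invented for this illustration.

```python
def and_gate_clauses(x, inputs):
    """CNF for x ≡ (i1 ∧ i2 ∧ ...), with nonzero ints as literals."""
    clauses = [(-x, lit) for lit in inputs]               # x → each input
    clauses.append(tuple(-lit for lit in inputs) + (x,))  # all inputs → x
    return clauses


def satisfies(clauses, assignment):
    """Check a CNF against an assignment mapping variable number -> bool."""
    return all(
        any(assignment[abs(lit)] == (lit > 0) for lit in clause)
        for clause in clauses
    )


# Variables: a=1, b=2, x=3.
clauses = and_gate_clauses(3, [1, 2])
print(clauses)  # [(-3, 1), (-3, 2), (-1, -2, 3)]
```

For k inputs this emits k + 1 clauses, matching the count described above.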
To really see why this helps, try converting (a ∧ b ∧ c) ∨ (d ∧ e ∧ f) ∨ (g ∧ h ∧ i) to an equivalent CNF with and without an extension variable for the first product term. Of course, we won't just stop with one term: if we abbreviate each term then the result is precisely a single CNF clause: (x ∨ y ∨ z) where these each abbreviate a single product term. This is a lot smaller!
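As a rough comparison for that exact formula, the sketch below counts clauses for both translations. The extension variable names x0, x1, x2 and the '~' prefix for negation are assumptions made for the example.

```python
from itertools import product

terms = [("a", "b", "c"), ("d", "e", "f"), ("g", "h", "i")]

# Equivalent CNF by distribution: one clause per way of picking one literal
# from each product term.
distributed = [tuple(choice) for choice in product(*terms)]

# Equisatisfiable CNF: abbreviate each term with a fresh variable.
tseitin = []
abbreviations = []
for index, term in enumerate(terms):
    ext = f"x{index}"  # hypothetical fresh-variable naming
    abbreviations.append(ext)
    tseitin.extend(("~" + ext, lit) for lit in term)           # ext → each literal
    tseitin.append(tuple("~" + lit for lit in term) + (ext,))  # whole term → ext
tseitin.append(tuple(abbreviations))  # at least one term holds

print(len(distributed), len(tseitin))  # 27 13
```

Three terms only show a modest gap, but the first number grows as k^n while the second grows as n(k + 1) + 1. For 35 terms of 5 literals that is 5**35 clauses versus 211.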
This approach can be used to turn any circuit into an equisatisfiable formula, linear in size and in CNF. This is called a Tseitin transformation. Your DNF formula is simply a circuit composed of a bunch of arbitrary fan-in AND gates, all feeding into a single arbitrary fan-in OR gate.
Best of all, although this formula is not equivalent due to additional variables, we can recover an assignment for the original formula by simply dropping the extension variables. It is sort of a 'best case' equisatisfiable formula, being a strict superset of the original.
To patch this into your code, I added:
# Uses pseudo-namespacing to avoid collisions.
_EXT_SUFFIX = "___"
_NEXT_EXT_INDEX = 0


def is_ext_var(element) -> bool:
    return element.endswith(_EXT_SUFFIX)


def ext_var() -> str:
    global _NEXT_EXT_INDEX
    ext_index = _NEXT_EXT_INDEX
    _NEXT_EXT_INDEX += 1
    return intern(f"{ext_index}{_EXT_SUFFIX}")
This lets us pull a fresh named variable out of thin air. Since these extension variable names don't have meaningful semantics to your solution display function, I changed:
point_to_value = {
    point: value for point, value in [fact.split() for fact in self.solution]
}
into:
point_to_value = {
    point: value
    for point, value in [
        fact.split() for fact in self.solution if not is_ext_var(fact)
    ]
}
There are certainly better ways to do this; this is just a patch. :)
Reimplementing your from_dnf with the above ideas, we get:
def from_dnf(groups) -> "cnf":
    "Convert from or-of-ands to and-of-ors, equisatisfiably"
    cnf = []
    extension_vars = []
    for group in groups:
        extension_var = ext_var()
        neg_extension_var = neg(extension_var)
        imply_ext_clause = []
        for literal in group:
            imply_ext_clause.append(neg(literal))
            cnf.append((neg_extension_var, literal))
        imply_ext_clause.append(extension_var)
        cnf.append(tuple(imply_ext_clause))
        extension_vars.append(extension_var)
    cnf.append(tuple(extension_vars))
    return cnf
Each group gets an extension variable. Each literal in the group adds its negation into the (k+1)-sized implication clause, and becomes implied by the extension. After the literals are handled, the extension variable finalizes the remaining implication and adds itself to the list of new extension variables. Finally, at least one of these extension variables must be true.
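As a sanity check on the claim that dropping the extension variables recovers assignments of the original formula, here is a self-contained brute-force sketch. It restates the translation with a local counter instead of ext_var, and it assumes the string convention where '~x' negates 'x'; neg, holds, and models are stand-ins written for this example, not the sat_utils versions.

```python
from itertools import count, product

_fresh = count()


def neg(lit):
    """Assumed convention: '~x' is the negation of 'x'."""
    return lit[1:] if lit.startswith("~") else "~" + lit


def from_dnf(groups):
    """Same structure as the translation above, with a local fresh-variable counter."""
    cnf, ext_vars = [], []
    for group in groups:
        ext = f"{next(_fresh)}___"
        cnf.extend((neg(ext), lit) for lit in group)           # ext → each literal
        cnf.append(tuple(neg(lit) for lit in group) + (ext,))  # group → ext
        ext_vars.append(ext)
    cnf.append(tuple(ext_vars))  # at least one group holds
    return cnf


def holds(lit, assignment):
    return not assignment[lit[1:]] if lit.startswith("~") else assignment[lit]


def models(cnf, variables):
    """Brute-force all satisfying assignments (only viable for tiny inputs)."""
    for values in product([False, True], repeat=len(variables)):
        assignment = dict(zip(variables, values))
        if all(any(holds(lit, assignment) for lit in clause) for clause in cnf):
            yield assignment


dnf = [("a", "b"), ("~a", "c")]
cnf = from_dnf(dnf)
originals = ["a", "b", "c"]
extras = sorted({lit.lstrip("~") for clause in cnf for lit in clause} - set(originals))
projected = {tuple(m[v] for v in originals) for m in models(cnf, originals + extras)}
expected = {
    (a, b, c)
    for a, b, c in product([False, True], repeat=3)
    if (a and b) or (not a and c)
}
print(projected == expected)  # True
```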
This change alone lets me solve this 5x5 puzzle ~instantly:
self.visible_from_top = [3, 2, 1, 4, 2]
self.visible_from_bottom = [2, 2, 4, 1, 2]
self.visible_from_left = [3, 2, 3, 1, 3]
self.visible_from_right = [2, 2, 1, 3, 2]
self.given_numbers = {}
I added some timing output as well:
# Needs "import time" at the top of the module.
@property
def solution(self):
    if self._solution is None:
        start_time = time.perf_counter()
        cnf = self.cnf
        cnf_time = time.perf_counter()
        print(f"CNF: {cnf_time - start_time}s")
        self._solution = solve_one(cnf)
        end_time = time.perf_counter()
        print(f"Solve: {end_time - cnf_time}s")
    return self._solution
The 5x5 puzzle gives me:
CNF: 0.00565183162689209s
Solve: 0.005589433014392853s
However, we still have that pesky k! growth when enumerating viable tower height permutations.
I generated a 9x9 puzzle (the largest the site permits), which corresponds to:
self.visible_from_top = [3, 3, 3, 3, 1, 4, 2, 4, 2]
self.visible_from_bottom = [3, 1, 4, 2, 5, 3, 3, 2, 3]
self.visible_from_left = [3, 3, 1, 2, 4, 5, 2, 3, 2]
self.visible_from_right = [3, 1, 7, 4, 3, 3, 2, 2, 4]
self.given_numbers = {
    "AB": 5,
    "AD": 4,
    "BD": 3,
    "BE": 2,
    "CD": 7,
    "CF": 5,
    "CG": 1,
    "DB": 1,
    "DH": 7,
    "EA": 4,
    "EI": 2,
    "FA": 2,
    "FE": 8,
    "GG": 7,
    "GI": 6,
    "HA": 3,
    "HF": 2,
    "HH": 1,
    "IG": 6,
}
This gives me:
CNF: 28.505195066332817s
Solve: 40.48229135945439s
We should spend more time solving and less time generating, but close to half the time is generating.
In my opinion, using DNF in a CNF-SAT translation is often† a sign of the wrong approach. Solvers are way better at exploring and learning about the solution space than we are; spending a factorial amount of time pre-exploring is actually worse than the solver's exponential worst case.
It's understandable to 'fall back' to DNF, because programmers naturally think in terms of "write an algorithm that emits solutions". But the real benefit of solvers kicks in when you encode this in the problem. Let the solver reason about conditions in which solutions become infeasible. To do this, we want to think in terms of circuits. Lucky for us, we also know how to turn a circuit into CNF quickly.
†I said "often"; if your DNF is small and quick to produce (like a single circuit gate), or if encoding it to a circuit is prohibitively complicated, then it can be helpful to pre-compute some of the solution space.
You've actually already done some of this! For example, we will need a circuit that counts how many times a certain number appears in a span (row or column), and an assertion that this number is exactly one. Then for each span and for each number, we'll emit this circuit. That way if a tower of size e.g. 3 appears twice in a row, the counter for that row for 3 will emit '2' and our assertion that it be '1' will not be upheld.
Your one_of constraint is one possible implementation of this. Yours uses the 'obvious' pairwise encoding: for each location in the span, if N is present at that location then it is not present in any other location. This is actually quite a good encoding because it's comprised almost entirely of binary clauses, and SAT solvers love binary clauses (they use significantly less memory and propagate often). But for really large sets of things to count, this O(n^2) scaling can become an issue.
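For reference, a pairwise 'exactly one' can be sketched as below. This is not necessarily how one_of in sat_utils is written; exactly_one_pairwise is a hypothetical name, and the inputs are assumed to be positive string literals with '~' marking negation.

```python
from itertools import combinations


def exactly_one_pairwise(literals):
    """Pairwise 'exactly one' encoding over positive string literals."""
    literals = list(literals)
    clauses = [tuple(literals)]  # at least one is true
    # No two are true at the same time: one binary clause per pair.
    clauses += [("~" + a, "~" + b) for a, b in combinations(literals, 2)]
    return clauses


for n in (4, 9, 100):
    cnf = exactly_one_pairwise(f"v{i}" for i in range(n))
    print(n, len(cnf))  # 1 + n*(n-1)/2 clauses: 7, 37, 4951
```

At puzzle sizes (n up to 9) the quadratic term is harmless; it only starts to matter for much larger groups.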
You can imagine an alternative approach where you literally encode an adder circuit: each location is an input bit to the circuit, and the circuit produces n bits of output telling you the final sum (the paper above is a good read!). You then assert this sum is exactly one using unit clauses that force specific output bits.
It may seem redundant to encode a circuit only to force some of its outputs to be a constant value. However, this is much easier to reason about and modern solvers are aware that encodings do this and optimize for it. They perform significantly more sophisticated in-processing than the initial encoding process could reasonably do. The 'art' of using solvers is in knowing and testing when these alternative encodings work better than others.
Note that exactly_k_of is at_least_k_of along with at_most_k_of. You've noted this in your Q class == implementation. Implementing at_least_1_of is trivial, being one clause; at_most_1_of is so common it's often just called AMO. I encourage you to try implementing < and > in some of the other ways discussed in the paper (perhaps even choosing which to use based on input size) to get a feel for it.
Turning our attention back to the k! visibility constraints, what we need is a circuit that tells us how many towers are visible from a certain direction, which we can then assert be a specific value.
Stop and think about how this could be done; it's not easy!
Analogous to the various one_of approaches, we can go with a 'pure' circuit for counting or use a simpler but worse-scaling pairwise approach. I have attached the sketch of the pure circuit approach at the very bottom (‡) of this answer. For now we will use the pairwise method.
The main observation to make is that among non-visible towers, we don't care about their permutations. Consider:
3 -> 1 5 _ _ _ 9 _ _ _
     A B C D E F G H I
We see 3 towers from the left as long as the CDE group contains 2, 3, and 4, and likewise if the GHI group contains 6, 7, and 8. But the order in which they appear in the group has no implication on the visible towers.
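A quick brute-force check of that observation: pin 1, 5, and 9 at A, B, and F, then try every ordering of the two hidden groups. The visible_count helper is a stand-in for the asker's visible_from_line.

```python
from itertools import permutations


def visible_count(line):
    """Count towers visible from the start of the line."""
    visible, highest_seen = 0, 0
    for height in line:
        if height > highest_seen:
            visible += 1
            highest_seen = height
    return visible


# Fix the visible towers 1, 5, 9 at A, B, F and shuffle the hidden groups.
counts = {
    visible_count((1, 5) + cde + (9,) + ghi)
    for cde in permutations((2, 3, 4))
    for ghi in permutations((6, 7, 8))
}
print(counts)  # {3}
```

All 36 arrangements give the same count, yet the DNF approach would list each of them as a separate product term.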
Rather than compute which towers are visible, we will declare which towers are visible and follow their implications. We'll be filling in this function:
def visible_in_span(points: Collection[str], desired: int) -> "cnf":
    """Assert desired visible towers in span. Wlog, visibility is from index 0."""
    points = list(points)
    n = len(points)
    assert desired <= n
    cnf = []
    # ...
    return cnf
Assume a fixed span and viewing direction: each location will have k associated variables, Av1 through Avk stating "this is the kth visible tower". We will also have Av ≡ (Av1 ∨ Av2 ∨ ⋯ ∨ Avk) meaning "A has a visible tower".
In the above example, Av1, Bv2, and Fv3 are all true. There are some obvious implications to emit. At a location, at most one of these is true (you can't be both the first and second visible tower), but not exactly one, since it's perfectly fine to have a non-visible tower. Another is that if a location is the kth visible tower, then no other location is also the kth visible tower.
We can add this so far:
    is_kth_visible_tower_at = {}
    # Needs "import collections" at the top of the module.
    is_kth_visible_tower_vars = collections.defaultdict(list)
    is_visible_tower_at = {}
    for point in points:
        is_visible_tower_vars = []
        for k in range(1, n + 1):
            # Xvk
            is_kth_visible_tower_var = ext_var()
            is_kth_visible_tower_at[(point, k)] = is_kth_visible_tower_var
            is_kth_visible_tower_vars[k].append(is_kth_visible_tower_var)
            is_visible_tower_vars.append(is_kth_visible_tower_var)
        # Xv
        is_visible_tower_at_var = ext_var()
        # Xv → (Xv1 ∨ Xv2 ∨ ⋯)
        cnf.append(tuple([neg(is_visible_tower_at_var)] + is_visible_tower_vars))
        # (Xv1 ∨ Xv2 ∨ ⋯) → Xv
        for is_visible_tower_var in is_visible_tower_vars:
            cnf.append((neg(is_visible_tower_var), is_visible_tower_at_var))
        is_visible_tower_at[point] = is_visible_tower_at_var
        # At most one visible tower here.
        cnf += Q(is_visible_tower_vars) <= 1
    # At most one kth visible tower anywhere.
    for k in range(1, n + 1):
        cnf += Q(is_kth_visible_tower_vars[k]) <= 1
Next we need ordering among visible towers, so that the (k+1)th visible tower comes after the kth visible tower. This is accomplished by the (k+1)th visible tower forcing at least one of the prior locations to be the kth visible tower. E.g., Dv3 → (Av2 ∨ Bv2 ∨ Cv2) and Cv2 → (Av1 ∨ Bv1). We know Av1 is always true, which provides the base case. (If we enter a situation like needing B to be the third visible tower, that will require A be the second visible tower, which contradicts Av1.)
    # Towers are ordered.
    for index, point in enumerate(points):
        if index == 0:
            cnf += basic_fact(is_kth_visible_tower_at[(point, 1)])
            continue
        for k in range(1, n + 1):
            # Xvk → ⋯
            implication = [neg(is_kth_visible_tower_at[(point, k)])]
            j = k - 1
            if j > 0:
                for index_j, point_j in enumerate(points):
                    if index_j == index:
                        break
                    # ⋯ ∨ Wvj ∨ ⋯
                    implication.append(is_kth_visible_tower_at[(point_j, j)])
            cnf.append(tuple(implication))
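To see what these ordering clauses look like with readable names, here is a standalone sketch that builds them for a 4-location span. ordering_clauses and the 'Dv3'-style names are purely illustrative; the real code uses opaque extension variables instead.

```python
def ordering_clauses(points):
    """Clauses 'Xvk → some earlier location is the (k-1)th visible tower',
    using names like 'Dv3' and '~' for negation."""
    n = len(points)
    clauses = [(f"{points[0]}v1",)]  # the first location is always the 1st visible tower
    for index, point in enumerate(points[1:], start=1):
        for k in range(1, n + 1):
            # For k == 1 there is no earlier level, leaving a unit clause ¬Xv1.
            earlier = tuple(f"{p}v{k - 1}" for p in points[:index]) if k > 1 else ()
            clauses.append((f"~{point}v{k}",) + earlier)
    return clauses


clauses = ordering_clauses("ABCD")
print(("~Dv3", "Av2", "Bv2", "Cv2") in clauses)  # True
print(("~Cv2", "Av1", "Bv1") in clauses)         # True
```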
So far so good, but we haven't related tower height to visibility. The above would allow 9 8 7 as a solution, calling 9 the first visible tower, 8 the second, and 7 the third. To solve this we want a tower placement to prohibit a smaller tower from also being visible.
Each location will again receive a set of abbreviations indicating whether towers up to a certain height are obscured there, called Ao1, Ao2, and so on. This will give us a 'grid' of implications that keep things simpler. The first is that a higher tower being obscured implies the next highest tower at the same location is also obscured, so that Ao3 → Ao2 and Ao2 → Ao1. The second is that if a tower is obscured at one location, it is also obscured at all later locations. This is Ao3 → Bo3 and Bo3 → Co3 and so on.
    is_height_obscured_at = {}
    is_height_obscured_previous = [None] * n
    for point in points:
        is_obscured_previous = None
        for k in range(1, n + 1):
            # Xok
            is_height_obscured_var = ext_var()
            # Wok → Xok
            is_k_obscured_previous = is_height_obscured_previous[k - 1]
            if is_k_obscured_previous is not None:
                cnf.append((neg(is_k_obscured_previous), is_height_obscured_var))
            # Xok → Xo(k-1)
            if is_obscured_previous is not None:
                cnf.append((neg(is_height_obscured_var), is_obscured_previous))
            is_height_obscured_at[(point, k)] = is_height_obscured_var
            is_height_obscured_previous[k - 1] = is_height_obscured_var
            is_obscured_previous = is_height_obscured_var
From this it's easy to see that stating e.g. Bo4 implies that towers of height 4 or less are obscured at B and at every later location. We can now easily relate tower placement to obscurity: A5 → Bo4.
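The effect of that implication grid can be simulated by following the two kinds of arrows from a single starting fact. This is only a reachability sketch over (location index, height) pairs, not part of the encoding itself; obscured_closure is a made-up name.

```python
def obscured_closure(points, start):
    """Follow Xok → Xo(k-1) and Xok → (next location)ok from one starting fact.
    Facts are (point_index, height) pairs."""
    seen = set()
    stack = [start]
    while stack:
        index, k = stack.pop()
        if (index, k) in seen:
            continue
        seen.add((index, k))
        if k > 1:
            stack.append((index, k - 1))  # lower heights at the same location
        if index + 1 < len(points):
            stack.append((index + 1, k))  # same height at the next location
    return seen


points = "ABCDE"
implied = obscured_closure(points, start=(1, 4))  # start from Bo4
print(sorted(f"{points[i]}o{k}" for i, k in implied))
```

Starting from Bo4 in a 5-location span, the closure covers heights 1 through 4 at B, C, D, and E, and nothing at A.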
    # A placed tower obscures smaller later towers.
    for index, point in enumerate(points):
        if index + 1 == len(points):
            break
        next_point = points[index + 1]
        for k in range(2, n + 1):
            j = k - 1
            # Xk → Yo(k-1)
            cnf.append((neg(comb(point, k)), is_height_obscured_at[(next_point, j)]))
Last, we need to relate obscurity to visibility. We'll need one final set of abbreviations, stating that a specific tower height is visible at a location. At the risk of making typos easy, we'll call this Ahv for some height h, so that Ahv ≡ (Ah ∧ Av). A concrete example would be C3v ≡ (C3 ∧ Cv): a tower of height 3 is visible at C if and only if there is a tower visible at C, and that tower is the height 3 tower.
    is_height_visible_at = {}
    for point in points:
        for k in range(1, n + 1):
            # Xhv
            height_visible_at_var = ext_var()
            # Xhv ≡ (Xh ∧ Xv)
            cnf.append((neg(height_visible_at_var), comb(point, k)))
            cnf.append((neg(height_visible_at_var), is_visible_tower_at[point]))
            cnf.append(
                (
                    neg(comb(point, k)),
                    neg(is_visible_tower_at[point]),
                    height_visible_at_var,
                )
            )
            is_height_visible_at[(point, k)] = height_visible_at_var
This allows us to emit the final implications on tower placement. If a tower of height h is obscured, it is not visible: Bo4 → ¬B4v. This is not an equivalence, and we cannot treat Bo4 ≡ ¬B4v; maybe ¬B4v holds because B4 simply isn't placed there (but would be visible if it were!).
    for point in points:
        for k in range(1, n + 1):
            # Xok → ¬Xkv
            cnf.append(
                (
                    neg(is_height_obscured_at[(point, k)]),
                    neg(is_height_visible_at[(point, k)]),
                )
            )
To relate this to the puzzle-specific visibility value, we just need to prohibit too many visible towers and ensure the desired kth visible tower appears at least once (and therefore exactly once):
    # At least one of the towers is the desired kth visible.
    cnf.append(tuple(is_kth_visible_tower_vars[desired]))
    # None of the towers can be visible above the desired kth.
    if desired < n:
        for is_kth_visible_tower_var in is_kth_visible_tower_vars[desired + 1]:
            cnf += basic_fact(neg(is_kth_visible_tower_var))
    return cnf
We only need to block the first level of undesirable kth visible towers. Since a visible tower at level k + 1 implies the existence of a visible tower at level k, every higher level is ruled out too. (And so on.)
Finally, we hook this into the CNF builder:
# Set visible from left
if self.visible_from_left:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_left[index]
        if not target_visible:
            continue
        cnf += visible_in_span(row, target_visible)

# Set visible from right
if self.visible_from_right:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_right[index]
        if not target_visible:
            continue
        cnf += visible_in_span(reversed(row), target_visible)

# Set visible from top
if self.visible_from_top:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_top[index]
        if not target_visible:
            continue
        cnf += visible_in_span(col, target_visible)

# Set visible from bottom
if self.visible_from_bottom:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_bottom[index]
        if not target_visible:
            continue
        cnf += visible_in_span(reversed(col), target_visible)
The above gives me the 9x9 solution much more quickly:
CNF: 0.028973951935768127s
Solve: 0.07169117406010628s
About 685x faster, and the solver is doing more of the overall work. Not bad for quick and dirty!
There are many ways to clean this up. E.g., every place we see cnf.append((neg(a), b)) could be cnf += implies(a, b) instead for readability. We could avoid allocation of pointlessly large kth visible variables, and so on.
This is not well-tested; I may have missed some implications or a rule. Hopefully it's easy to fix at this point.
The last thing I want to touch on is the applicability of SAT. Perhaps painfully clear now, SAT solvers are not exactly great at counting and arithmetic. You have to lower to a circuit, hiding the higher-level semantics from the solving process.
Other approaches will let you express arithmetic, intervals, sets, and so on naturally. Answer set programming (ASP) is one example of this, SMT solvers are another. For small problems SAT is fine, but for difficult problems these higher-level approaches can greatly simplify the problem.
Each of those may actually internally decide to reason via SAT-solving (SMT in particular), but they will be doing so in the context of some higher-level understanding of the problem.
‡ This is the pure circuit approach to counting towers.
Whether or not it is better than pairwise will depend on the number of towers being counted; maybe the constant factors are so high it's never useful, or maybe it's quite useful even at low sizes. I honestly have no idea — I've encoded huge circuits before and had them work great. It requires experimentation to know.
I'm going to call Ah the integer height of the tower in location A. That is, rather than a one-hot encoding of either A1 or A2 or … or A9 we'll have Ah0, Ah1, …, and Ahn as the low through high bits of an n-bit integer (collectively Ah). For a limit of 9x9, 4 bits suffice. We'll also have Bh, Ch, and so on.
You can join the two representations using A1 ≡ (¬Ah3 ∧ ¬Ah2 ∧ ¬Ah1 ∧ Ah0) and A2 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ ¬Ah0) and A3 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ Ah0) and so on. We have Ah = 3 if and only if A3 is set. (We don't need to add constraints that only one value of Ah is possible at a time, since the one-hot variables associated to each do this already.)
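Joining the one-hot and binary representations is the same product-term abbreviation as before, just with some bits negated. A minimal sketch with integer literals (negative means negated); channel_clauses and the variable numbering are assumptions for the example.

```python
def channel_clauses(onehot, bits, value):
    """CNF for onehot ≡ (bits encode `value`). Variables are positive ints,
    bits[0] is the least significant bit; a negative int is a negation."""
    # Each bit appears positively if set in `value`, negatively otherwise.
    pattern = [b if (value >> i) & 1 else -b for i, b in enumerate(bits)]
    clauses = [(-onehot, lit) for lit in pattern]               # onehot → each bit matches
    clauses.append(tuple(-lit for lit in pattern) + (onehot,))  # all bits match → onehot
    return clauses


# Hypothetical numbering: Ah0..Ah3 are variables 1..4, A3 is variable 5.
bits, a3 = [1, 2, 3, 4], 5
clauses = channel_clauses(a3, bits, value=3)
print(clauses)
# [(-5, 1), (-5, 2), (-5, -3), (-5, -4), (-1, -2, 3, 4, 5)]
```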
With an integer in hand, it might be easier to see how to compute visibility. We can associate each location with a maximum seen tower height, named Am, Bm, and so on; obviously the first tower is always visible and the highest seen, so Am ≡ Ah. Again, this is actually an n-bit value Am0 through Amn.
A tower is visible if and only if its value is larger than the highest seen up to the prior location. We'll track visibility with Av, Bv, and so on. This can be done with a digital comparator, so that Bv ≡ Bh > Am. (Av is a base case, and is simply always true.)
This lets us fill in the rest of the max values as well. Bm ≡ Bv ? Bh : Am, and so on. A conditional/if-then-else/ite is a digital multiplexer. For a simple 2-to-1 this is straightforward: Bv ? Bh : Am is (Bv ∧ Bh) ∨ (¬Bv ∧ Am), which is really (Bv ∧ Bhi) ∨ (¬Bv ∧ Ami) for each i ∈ 0..n.
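The multiplexer formula is easy to check against an ordinary conditional on plain booleans. This sketch only evaluates the gate; it does not emit CNF, and mux and mux_word are names chosen for the illustration.

```python
from itertools import product


def mux(select, when_true, when_false):
    """2-to-1 multiplexer on single bits: select ? when_true : when_false."""
    return (select and when_true) or ((not select) and when_false)


# The formula agrees with a plain conditional on every input combination.
agrees = all(
    mux(bv, bh, am) == (bh if bv else am)
    for bv, bh, am in product([False, True], repeat=3)
)


def mux_word(select, word_true, word_false):
    """Apply the same multiplexer to each bit position of two equal-width words."""
    return [mux(select, t, f) for t, f in zip(word_true, word_false)]


print(agrees)  # True
```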
Then, we'll have a bunch of single inputs Av through Iv that feed into an adder circuit, telling us how many of these inputs are true (i.e., how many towers are visible). This will be yet another n-bit value; then we use unit clauses to assert that it is exactly e.g. 3, if the particular puzzle demands 3 visible towers.
We generate this same circuit for every span in every direction. This will be some polynomial-sized encoding of the rules, adding many extension variables and many clauses. A solver can learn a certain tower placement isn't viable not because we said so, but because it implies some unacceptable intermediate visibility. "There should be 4 visible, and 2 are already visible, so that leaves me with...".
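To make the data flow concrete, here is the same pipeline evaluated on known heights in ordinary Python: comparator, then multiplexer, then a sum over the visibility bits. It is a simulation of what the circuit computes, not an encoding, and simulate_visibility_circuit is a hypothetical name.

```python
def simulate_visibility_circuit(heights):
    """Evaluate the per-location signals on concrete heights (no SAT involved).
    Returns (visible_flags, running_maxima, visible_total)."""
    visible_flags, running_maxima = [], []
    previous_max = 0
    for index, height in enumerate(heights):
        # Comparator: Xv ≡ Xh > Wm, with the first location always visible.
        visible = True if index == 0 else height > previous_max
        # Multiplexer: Xm ≡ Xv ? Xh : Wm.
        current_max = height if visible else previous_max
        visible_flags.append(visible)
        running_maxima.append(current_max)
        previous_max = current_max
    # Adder: count how many of the Xv bits are set.
    return visible_flags, running_maxima, sum(visible_flags)


flags, maxima, total = simulate_visibility_circuit([1, 5, 2, 3, 4, 9, 6, 7, 8])
print(total)  # 3
```

In the SAT version each of these intermediate values becomes a handful of extension variables, and the final unit clauses pin the sum to the clue.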