I want to implement a Minimax AI for Tic Tac Toe of orders 3, 4 and 5.
Rules of the game:
For an order n Tic Tac Toe game, the board consists of n rows and n columns, for a total of n² cells. The board is initially empty.
There are two players who move alternately; no player can abstain from moving or move in two consecutive turns. In each turn a player must choose a cell that hasn't previously been chosen.
The game ends when either player has occupied a complete row, column or diagonal of n cells, or when all cells are occupied.
A cell can be in one of 3 states, so a naive upper bound on the number of states, disregarding the rules, is 3^(n²). For order 3 that is 19,683, for order 4 it is 43,046,721, and for order 5 it is 847,288,609,443.
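The naive bound above is easy to reproduce. The following is only a small sketch; the function name `naive_upper_bound` is made up for illustration.

```python
def naive_upper_bound(n: int) -> int:
    """Number of ways to fill n*n cells with 3 possible values each,
    ignoring every rule of the game."""
    return 3 ** (n * n)


for order in (3, 4, 5):
    print(order, f"{naive_upper_bound(order):,}")
# 3 19,683
# 4 43,046,721
# 5 847,288,609,443
```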
I have programmatically enumerated all legal states for orders 3 and 4, counting rotations and reflections as distinct boards. For order 3, 5,478 states are reachable if "O" moves first and 5,478 if "X" moves first, for a total of 8,533 unique reachable states. For order 4, 9,722,011 states are reachable whichever player moves first, for a total of 14,782,023 states.
I don't have the exact number of states for order 5. If players move alternately and the game does not end when there is a winner, 161,995,031,226 states are reachable, so the number of legal states is less than that. I estimate the error of my calculation is within 10%.
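A count that ignores the game-over condition can be written in closed form: after k moves the first player holds ceil(k/2) cells and the second holds floor(k/2). The sketch below assumes one fixed first player; it is not known whether the figure quoted above was computed under exactly this convention, so the function (a hypothetical name) is only checked here against the well-known order 3 value.

```python
from math import comb


def count_positions_ignoring_wins(n: int) -> int:
    """Boards reachable by strictly alternating moves with a fixed first
    player, assuming play never stops early when someone completes a line."""
    cells = n * n
    total = 0
    for k in range(cells + 1):      # k = number of moves played so far
        first = (k + 1) // 2        # cells held by the player who moved first
        second = k // 2             # cells held by the other player
        total += comb(cells, first) * comb(cells - first, second)
    return total


print(count_positions_ignoring_wins(3))  # 6046
```

For order 3 this gives 6,046, which is an upper bound on the 5,478 legal states enumerated for a single first player.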
I have previously implemented a working reinforcement learning AI for Tic Tac Toe order 3, but wasn't satisfied by its performance.
So I have tried to implement a Minimax AI for Tic Tac Toe. The only relevant thing I have found is this, but the code quality is horrible and the code doesn't actually work.
So I tried to implement my own version based on it.
Because a player has either occupied a cell or not, the state is binary, so for one player the board can be represented by n² bits; as there are two players, we need 2n² bits for the full information about the board.
1 indicates the player has occupied a cell, 0 indicates a cell is not occupied by the player. Denote the integers that encode player information as (o, x); o and x cannot have common set bits. To get the full occupancy of the board, use full = o | x, so a set bit in full means the corresponding cell is occupied by either player, otherwise the cell isn't occupied.
I pack the board into one integer using board = o << n * n | x to store the information as efficiently as possible. Even so, storing the encoded states for order 4 takes more than 1 GiB of RAM, and storing the boards with their legal moves for order 4 takes more than 7 GiB (I have 16 GiB of RAM). The board is unpacked by o = board >> n * n; x = board & (1 << n * n) - 1.
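As a minimal sketch of the packing scheme just described (the helper names `pack_board` and `unpack_board` are made up for illustration):

```python
from typing import Tuple


def pack_board(o: int, x: int, n: int) -> int:
    """Store o in the high n*n bits and x in the low n*n bits."""
    assert o & x == 0, "a cell cannot belong to both players"
    return o << n * n | x


def unpack_board(board: int, n: int) -> Tuple[int, int]:
    """Inverse of pack_board: recover (o, x) from the packed integer."""
    cells = n * n
    return board >> cells, board & (1 << cells) - 1


o, x = 0b100000001, 0b000010000   # O on two corners, X in the centre (order 3)
board = pack_board(o, x, 3)
print(unpack_board(board, 3) == (o, x))  # True
```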
Counting from left to right and from top to bottom, with rows and columns numbered from 0, the cell at row r, column c corresponds to full & 1 << ((n - r) * n - 1 - c). A move is set by bit-wise OR, |.
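The cell-to-bit mapping can be sketched as follows, assuming 0-based row and column indices (which is what the formula implies: the top-left cell is the most significant bit). `cell_mask` and `place` are hypothetical helper names.

```python
def cell_mask(n: int, r: int, c: int) -> int:
    """Bit for the cell at 0-based row r, column c; top-left is the highest bit."""
    return 1 << ((n - r) * n - 1 - c)


def place(player: int, other: int, n: int, r: int, c: int) -> int:
    """Return the player's bitboard after claiming cell (r, c)."""
    mask = cell_mask(n, r, c)
    assert not (player | other) & mask, "cell already occupied"
    return player | mask


print(bin(cell_mask(3, 0, 0)))  # 0b100000000 (top-left)
print(bin(cell_mask(3, 2, 2)))  # 0b1 (bottom-right)
```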
A fully occupied board has no unset bits, therefore full.bit_count() == n * n. A board has n rows, n columns and 2 diagonals; the winner is determined by generating the bit masks for all 2n + 2 lines and iterating through the masks to find whether any o & mask == mask or x & mask == mask.
And finally, when there is exactly one gap in a line and all other cells in that line are occupied by one player, that player can win on their next move. So a rational player should always choose such a gap when it is their turn, either to win the game if they occupy the other cells or to prevent the other player from winning. When such gaps exist, the AI should consider only those gaps.
Thus, the winning strategy is to create at least two such gaps simultaneously, so that the opponent can block only one of them and the player wins on the next turn. This requires at least 2n - 3 cells to be occupied by the player. Assuming the player moves first, the other player will have taken 2n - 4 turns, so the total number of turns so far is 4n - 7. The next move is the opponent's, so the AI should seek to win in 4n - 5 turns if it moves first, or 4n - 4 turns if it moves second.
The following is the code I used to enumerate all Tic Tac Toe legal states for order 3 (and order 4, the code for order 4 is omitted for brevity, but it can be obtained by trivially changing some numbers):
from typing import List, Tuple


def pack(line: range, last: int) -> int:
    return (sum(1 << last - i for i in line), tuple(line))


def generate_lines(n: int) -> List[Tuple[int, Tuple[int]]]:
    square = n * n
    last = square - 1
    lines = []
    for i in range(n):
        lines.extend(
            (
                pack(range(i * n, i * n + n), last),
                pack(range(i, square, n), last),
            )
        )

    lines.extend(
        (
            pack(range(0, square, n + 1), last),
            pack(range((m := n - 1), n * m + 1, m), last),
        )
    )
    return lines


LINES_3 = generate_lines(3)
FULL3 = (1 << 9) - 1
GAMESTATES_3_P1 = {}
GAMESTATES_3_P2 = {}


def check_state_3(o: int, x: int) -> Tuple[bool, int]:
    for line, _ in LINES_3:
        if o & line == line:
            return True, 0
        elif x & line == line:
            return True, 1

    return (o | x).bit_count() == 9, 2


def process_states_3(board: int, move: bool, states: dict, moves: List[int]) -> None:
    if board not in states:
        o = board >> 9
        x = board & FULL3
        if not check_state_3(o, x)[0]:
            left = 8 + 9 * move
            for i, n in enumerate(moves):
                process_states_3(
                    board | 1 << left - n, not move, states, moves[:i] + moves[i + 1 :]
                )
        c = len(moves)
        states[board] = {i: 1 << c for i in moves}


process_states_3(0, 1, GAMESTATES_3_P1, list(range(9)))
process_states_3(0, 0, GAMESTATES_3_P2, list(range(9)))
The following is my reimplementation of the code found in the linked article.
MINIMAX_STATES_3_P1 = {}
SCORES_3 = (10, -10, 0)


def minimax_search_3(board: int, states: dict, maximize: bool, moves: List[int]) -> int:
    if score := states.get(board):
        return score

    o = board >> 9
    x = board & FULL3
    over, winner = check_state_3(o, x)
    if over:
        score = SCORES_3[winner]
        states[board] = score
        return score

    left = 8 + 9 * maximize
    best, extreme, maximize = (-1e309, max, False) if maximize else (1e309, min, True)
    for i, n in enumerate(moves):
        best = extreme(
            best,
            minimax_search_3(
                board | 1 << left - n, states, maximize, moves[:i] + moves[i + 1 :]
            ),
        )

    states[board] = best
    return best


minimax_search_3(0, MINIMAX_STATES_3_P1, 1, list(range(9)))
It doesn't work at all: all scores are either 10, 0 or -10, because it doesn't take recursion depth into account. The function in the article also evaluates the same states over and over again, because a state can be reached through different move orders; I fixed that redundancy by caching.
At minimum, the AI should stop recursing once a given number of turns is reached, wins that occur later should carry less weight, and the score of a state should vary with how many winning states it leads to. And, as mentioned before, when there are gaps the AI should consider only those gaps.
I have tried to fix the problems myself and wrote the following Minimax-ish function, but I don't actually know Minimax theory and I don't know if it works:
def generate_gaps(lines: List[Tuple[int, Tuple[int]]], l: int):
    k = l * l - 1
    return [
        (sum(1 << k - n for n in line[:i] + line[i + 1 :]), 1 << k - line[i], line[i])
        for _, line in lines
        for i in range(l)
    ]


GAPS_3 = generate_gaps(LINES_3, 3)


def find_gaps_3(board: int, player: int) -> int:
    return [i for mask, pos, i in GAPS_3 if player & mask == mask and not board & pos]


MINIMAX_3 = {}


def my_minimax_search_3(
    board: int, states: dict, maximize: bool, moves: List[int], turns: int, depth: int
) -> int:
    if entry := states.get(board):
        return entry["score"]

    o = board >> 9
    x = board & FULL3
    over, winner = check_state_3(o, x)
    if over:
        score = SCORES_3[winner] * 1 << depth
        states[board] = {"score": score}
        return score

    if (full := o | x).bit_count() > turns:
        return 0

    depth -= 1
    left, new = (17, False) if maximize else (8, True)
    gaps = set(find_gaps_3(full, o) + find_gaps_3(full, x))
    weights = {
        n: my_minimax_search_3(
            board | 1 << left - n, states, new, moves[:i] + moves[i + 1 :], depth
        )
        for i, n in enumerate(moves)
        if not gaps or n in gaps
    }
    score = [-1, 1][maximize] * sum(weights.values())
    states[board] = {"weights": weights, "score": score}
    return score


my_minimax_search_3(0, MINIMAX_3, 1, list(range(9)), 9, 9)
How should I properly implement Minimax for Tic Tac Toe?
The number of states is not that relevant. Minimax (with depth) will just have to look at the states that are reachable with that depth.
What you need is a good enough evaluation function in case the depth has been reached without a winner. Just returning 0 is not a very informed evaluation function.
A candidate as evaluation function could be based on the number of lines where a player could still win, i.e. those lines where the opponent has no occupation yet. For such lines you could give credit for the number of squares that the player has already occupied in such a line: the more the better. I would make this quadratic, so that just one occupied square counts as 1, but 2 counts as 4, and 3 as 9, ...
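To make the quadratic scoring concrete, here is a small worked sketch on a 3x3 board. The helper names (`line_masks`, `open_line_score`, `evaluate`) are illustrative assumptions, and squares are numbered 0 to 8 from the top-left with bit i standing for square i.

```python
from typing import List


def line_masks(n: int) -> List[int]:
    """Bit masks for the n rows, n columns and 2 diagonals (bit i = square i)."""
    rows = [sum(1 << (r * n + c) for c in range(n)) for r in range(n)]
    cols = [sum(1 << (r * n + c) for r in range(n)) for c in range(n)]
    diag = sum(1 << (i * n + i) for i in range(n))
    anti = sum(1 << (i * n + n - 1 - i) for i in range(n))
    return rows + cols + [diag, anti]


def open_line_score(player: int, opponent: int, masks: List[int]) -> int:
    """Sum of (own squares in the line) squared, over lines the opponent has not touched."""
    return sum(
        bin(player & mask).count("1") ** 2
        for mask in masks
        if opponent & mask == 0
    )


def evaluate(player: int, opponent: int, masks: List[int]) -> int:
    """Positive values favour `player`, negative values favour `opponent`."""
    return open_line_score(player, opponent, masks) - open_line_score(opponent, player, masks)


masks = line_masks(3)
x = 1 << 4   # X holds the centre
o = 1 << 0   # O holds the top-left corner
print(open_line_score(x, o, masks))  # 3: middle row, middle column, anti-diagonal
print(open_line_score(o, x, masks))  # 2: top row, left column
print(evaluate(x, o, masks))         # 1
```

Two squares in an untouched line count as 4 rather than 2, so the evaluation rewards concentrating on one line over spreading single squares across several.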
With such an evaluation function, it turns out that a depth of 2 is already enough for the AI to never lose a game. I must add that against a bad player it would be much harder to win on a 5x5 board than on a 3x3 board. On a 5x5 board it is very easy to block every possibility for the opponent to win. But all these games are a tie with best play.
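For the 3x3 case, the claim that best play ends in a tie can be checked with a small exhaustive search. This is a separate sketch rather than the depth-limited approach described here: it searches to the end of the game, memoizes positions with `functools.lru_cache`, and uses hypothetical names (`solve`, `WINS`). It says nothing about the 4x4 or 5x5 boards, which it does not attempt to search.

```python
from functools import lru_cache

N = 3
FULL = (1 << N * N) - 1

# Masks for rows, columns and both diagonals (bit i = square i).
WINS = tuple(
    [sum(1 << (r * N + c) for c in range(N)) for r in range(N)]
    + [sum(1 << (r * N + c) for r in range(N)) for c in range(N)]
    + [sum(1 << (i * N + i) for i in range(N))]
    + [sum(1 << (i * N + N - 1 - i) for i in range(N))]
)


@lru_cache(maxsize=None)
def solve(to_move: int, waiting: int) -> int:
    """Game value for the side to move: 1 = win, 0 = draw, -1 = loss."""
    if any(waiting & mask == mask for mask in WINS):
        return -1                      # the previous move completed a line
    if to_move | waiting == FULL:
        return 0                       # board full, nobody won
    best = -1
    for square in range(N * N):
        bit = 1 << square
        if (to_move | waiting) & bit:
            continue                   # square already taken
        # Negamax: the opponent's value, negated, is our value for this move.
        best = max(best, -solve(waiting, to_move | bit))
        if best == 1:
            break                      # cannot do better than a forced win
    return best


print(solve(0, 0))  # 0: the empty 3x3 board is a draw with best play
```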
Here is an implementation of minimax (implemented with the principle of negamax) with alpha-beta pruning. There is no need to memoize boards, as the depth is shallow.
from random import choice

INF = float("inf")


class TicTacToe:
    def __init__(self, size=3):
        self.players = [0, 0]  # bits are 1 when occupied by player
        self.turn = 0
        self.size = size
        self.wins = []
        self.full = (1 << (size * size)) - 1
        # set masks for horizontal wins
        mask = (1 << size) - 1
        for i in range(size):
            self.wins.append(mask)
            mask <<= size
        # set masks for vertical wins
        mask = 0
        for i in range(size):
            mask = (mask << size) | 1
        for i in range(size):
            self.wins.append(mask)
            mask <<= 1
        # set masks for diagonals
        mask = 0
        for i in range(size):
            mask = (mask << (size + 1)) | 1
        self.wins.append(mask)
        mask = 0
        for i in range(size):
            mask = (mask << (size - 1)) | 1
        self.wins.append(mask << (size - 1))

    def move(self, square):
        mask = 1 << square
        self.players[self.turn] ^= mask
        self.turn = 1 - self.turn

    def undo(self, square):
        mask = 1 << square
        self.turn = 1 - self.turn
        self.players[self.turn] ^= mask

    def moves(self):
        board = self.players[0] | self.players[1]
        for square in range(self.size ** 2):
            if board & (1 << square) == 0:
                yield square

    def is_won(self):
        # Only the player who moved last can have just completed a line
        x = self.players[1 - self.turn]
        return any(x & mask == mask for mask in self.wins)

    def is_full(self):
        return self.players[0] | self.players[1] == self.full

    def is_over(self):
        return self.is_full() or self.is_won()

    def threats(self, turn):  # for use by heuristic evaluation
        player = self.players[turn]
        opponent = self.players[1 - turn]
        return sum(
            (player & mask).bit_count() ** 2
            for mask in self.wins
            if opponent & mask == 0
        )

    def heuristic(self):
        return self.threats(self.turn) - self.threats(1 - self.turn)

    def minimax(self, depth, alpha, beta):
        if self.is_won():
            return None, -INF
        if self.is_full():
            return None, 0
        if depth == 0:
            return None, self.heuristic()
        # Start with a list (not None) so that a first move scoring -INF,
        # which equals the initial best_score, can still be appended
        best_moves = []
        best_score = -INF
        for square in self.moves():
            self.move(square)
            score = -self.minimax(depth - 1, -beta, -alpha)[1]
            self.undo(square)
            if score == best_score:
                best_moves.append(square)  # Collect equally valued moves
            elif score > best_score:
                best_score = score
                best_moves = [square]
                if score > beta:
                    break
                alpha = max(alpha, score)
        return choice(best_moves), best_score

    def best_move(self, depth=2):
        return self.minimax(depth, -INF, INF)[0]

    def random_move(self):  # To allow testing with a stupid opponent
        return choice(list(self.moves()))

    def __repr__(self):
        s = ""
        for i in range(self.size * self.size):
            mask = 1 << i
            s += " " if i % self.size else "\n"
            s += ".XO"[(self.players[0] & mask > 0) + (self.players[1] & mask > 0) * 2]
        return s.strip()


# Play a 4x4 game between AI and random mover
game = TicTacToe(4)
while not game.is_over():
    move = game.random_move() if game.turn else game.best_move()
    game.move(move)
    print(game)
    print()