Tags: haskell, recursion, monads, backtracking, state-monad

How to solve Alphametics puzzle using the State Monad and mutable Vector?


I'm working on the Alphametics puzzle:

A set of words is written down in the form of an ordinary "long-hand" addition sum, and it is required that the letters of the alphabet be replaced with decimal digits so that the result is a valid arithmetic sum. Example:

SEND
MORE
-----
MONEY

This equation has a unique solution:

9567
1085
-----
10652

A non-brute-force solution is to use backtracking with memoization. I chose to use the State monad along with mutable Vectors.

The algorithm goes as follows:

If we are beyond the leftmost digit of the sum:
  Return true if no carry, false otherwise.
  Also check that there is no leading zero in the sum.
Else if addend and current column index is beyond the current row:
  Recur on row beneath this one.

If we are currently trying to assign a char in one of the addends:
  If char already assigned, recur on row beneath this one.
  If not assigned, then:
    For every possible choice among the digits not in use:
      Make that choice and recur on row beneath this one.
        If successful, return true.
        Else, unmake assignment and try another digit.
    Return false if no assignment worked to trigger backtracking.

Else if trying to assign a char in the sum:
  If char already assigned:
    If matches the sum digit, recur on next column to the left with carry.
    Else, return false to trigger backtracking.
  If char unassigned:
    If correct digit already used, return false.
    Else:
      Assign it and recur on next column to the left with carry:
        If successful return true.
        Else, unmake assignment, and return false to trigger backtracking.
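
For comparison, the steps above can be sketched without any mutable state, using the list monad for backtracking and a `Data.Map` for the assignments. This is only a minimal sketch, not the approach asked about: the names `solveSketch`, `Assignment`, `addendStep` and `sumStep` are made up for illustration, the puzzle is assumed to be already split into addends and sum, and no addend is assumed to be longer than the sum.

```haskell
import qualified Data.Map.Strict as M
import Data.Maybe (listToMaybe)

-- Hypothetical type: letter -> digit assignments made so far.
type Assignment = M.Map Char Int

-- Words are reversed so that column 0 is the least significant digit.
-- Assumes no addend is longer than the sum.
solveSketch :: [String] -> String -> Maybe Assignment
solveSketch addends total = listToMaybe (go 0 0 0 M.empty)
  where
    rows = map reverse addends
    res = reverse total
    -- Leading letters of multi-letter words must not be zero.
    leading = [c | (c : _ : _) <- total : addends]
    allowed c d = d /= 0 || c `notElem` leading

    -- Each element of the result list is one complete solution;
    -- an empty list means "backtrack".
    go :: Int -> Int -> Int -> Assignment -> [Assignment]
    go row col carry env
      | row < length rows = addendStep (rows !! row)
      | col >= length res = [env | carry == 0]
      | otherwise = sumStep (res !! col)
      where
        used = M.elems env
        addendStep word
          | col >= length word = go (row + 1) col carry env
          | otherwise =
              let c = word !! col
               in case M.lookup c env of
                    Just d -> go (row + 1) col (carry + d) env
                    Nothing ->
                      [ s
                      | d <- [0 .. 9]
                      , d `notElem` used
                      , allowed c d
                      , s <- go (row + 1) col (carry + d) (M.insert c d env)
                      ]
        sumStep c =
          let d = carry `mod` 10
              next = go 0 (col + 1) (carry `div` 10)
           in case M.lookup c env of
                Just d'
                  | d' == d -> next env
                  | otherwise -> []
                Nothing
                  | d `elem` used || not (allowed c d) -> []
                  | otherwise -> next (M.insert c d env)
```

Because the assignment map is immutable, "unmaking" a choice needs no code: a failed branch simply yields an empty list and the next digit is tried with the old map. For example, `solveSketch ["SEND", "MORE"] "MONEY"` should yield the assignment shown at the top of the question.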

I'm having trouble writing the part where a number is assigned to an addend.

Here is the Rust code for reference; this is the part that needs to be translated to Haskell:

let used: HashSet<&u8> = HashSet::from_iter(solution.values());
let unused: Vec<u8> = (0..=9).filter(|x| !used.contains(x)).collect();
for i in unused {
    if i == 0 && non_zero_letters.contains(&letter) {
        continue;
    }
    solution.insert(letter, i);
    if can_solve(
        equation,
        result,
        non_zero_letters,
        row + 1,
        col,
        carry + (i as u32),
        solution,
    ) {
        return true;
    }
    solution.remove(&letter);
}
false
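
One possible Haskell counterpart of this loop, sketched in plain `ST` with an unboxed mutable vector. The names `tryDigits`, `continue` and `demo` are hypothetical; `continue` stands in for the recursive `can_solve` call, and the vector layout (one slot per letter, `-1` meaning unassigned) is an assumption.

```haskell
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Unboxed.Mutable as UM

-- Try every unused digit for the letter stored in slot `ix`.
-- Assumed layout: one slot per letter, -1 means "unassigned".
tryDigits ::
  UM.MVector s Int ->     -- current assignments
  Bool ->                 -- may this letter be zero?
  Int ->                  -- slot of the letter being assigned
  (Int -> ST s Bool) ->   -- stand-in for the recursive call; gets the digit
  ST s Bool
tryDigits solution canBeZero ix continue = do
  -- Reading the vector is an effect, so it happens inside the monad.
  used <- mapM (UM.read solution) [0 .. UM.length solution - 1]
  let unused = [d | d <- [0 .. 9], d `notElem` used, d > 0 || canBeZero]
      go [] = return False              -- nothing worked: backtrack
      go (d : ds) = do
        UM.write solution ix d          -- make the choice
        ok <- continue d
        if ok
          then return True
          else do
            UM.write solution ix (-1)   -- unmake the choice
            go ds
  go unused

-- Small demonstration: digits 0 and 1 are taken, and the stand-in
-- "recursive call" only accepts the digit 4.
demo :: (Bool, [Int])
demo = runST $ do
  solution <- UM.replicate 3 (-1)
  UM.write solution 0 0
  UM.write solution 1 1
  ok <- tryDigits solution True 2 (\d -> return (d == 4))
  final <- mapM (UM.read solution) [0 .. 2]
  return (ok, final)
```

The explicit recursion in `go` plays the role of the `for` loop with an early `return true`. In the real solver the continuation would be something like `\d -> canSolve (row + 1) col (carry + d)`.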

My code, which I have yet to compile and which does not implement the case above, is shown below:

equation contains the addend rows. result is the sum row. solution is the assignments. nonZeroLetters is an optimization that checks there are no leading zeros in any of the rows.

solve :: String -> Maybe [(Char, Int)]
solve puzzle = error "You need to implement this function."

type Solution = Vector Int

type Row = Vector Char

data PuzzleState = PuzzleState
  { equation :: Vector Row,
    result :: Row,
    nonZeroLetters :: Set Char,
    solution :: MVector Row
  }

canSolve :: Int -> Int -> Int -> State PuzzleState Bool
canSolve row col carry = do
  PuzzleState {equation, result, nonZeroLetters, solution} <- get

  let addend = row < length equation
  let word = if addend then (equation ! row) else result
  let n = length word
  let letter = word ! col

  let ord x = C.ord x - C.ord 'A'
  let readC = UM.read (solution . ord)

  i <- readC letter
  let assigned = i >= 0

  let isNonZero = flip S.member nonZeroLetters

  case () of
    _
      | col >= n && addend -> canSolve (row + 1) col carry
      | col == n && (not . addend) -> carry == 0
      | addend && assigned -> canSolve (row + 1) col (carry + i)

ord :: Char -> Int
ord x = C.ord x - C.ord 'A'

readC ::
  (PrimMonad m, UM.Unbox a) =>
  MV.MVector (PrimState m) a ->
  Char ->
  m a
readC solution c = UM.read solution $ ord c

writeC ::
  (PrimMonad m, UM.Unbox a) =>
  UM.MVector (PrimState m) a ->
  Char ->
  a ->
  m ()
writeC solution c x = UM.write solution $ ord c $ x

Here's the (invalid and incomplete) draft that I need help with. This is the part for which I showed Rust code above.

| addend -> let used <- M.mapM (0 <= UM.read solution) [0..length solution - 1]
                unused = filter (\x -> x == 0 && isNonZero x) [0..9] \\ used
                  in do
                    i <- unused
                    writeC letter

Edit Jan 7, 2023:

Here's the cleaned-up code that produces the compilation error shown at the end.

{-# LANGUAGE NamedFieldPuns #-}

module Alphametics (solve) where

import Control.Monad as M
import Control.Monad.Reader (ReaderT)
import qualified Control.Monad.Reader as R
import Control.Monad.ST (ST)
import qualified Control.Monad.ST as ST
import qualified Data.Char as C
import Data.List ((\\))
import Data.Set (Set)
import qualified Data.Set as S
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Data.Vector.Unboxed.Mutable (MVector)
import qualified Data.Vector.Unboxed.Mutable as UM

solve :: String -> Maybe [(Char, Int)]
solve puzzle = error "You need to implement this function."

data PuzzleState s = PuzzleState
  { equation :: V.Vector (U.Vector Char),
    result :: U.Vector Char,
    nonZeroLetters :: Set Char,
    solution :: MVector s Int
  }

type M s = ReaderT (PuzzleState s) (ST s)

canSolve :: Int -> Int -> Int -> M s Bool
canSolve row col carry = do
  PuzzleState {equation, result, nonZeroLetters, solution} <- R.ask

  let addend = row < length equation
  let word = if addend then ((V.!) equation row) else result
  let n = length word
  let letter = (U.!) word col
  let x = ord letter
  y <- R.lift $ UM.read solution x
  let assigned = y >= 0
  let isNonZero = flip S.member nonZeroLetters
  let sumDigit = carry `mod` 10

  let used = filter (\i -> 0 <= UM.read solution i) [0 .. length solution - 1]

  case () of
    _
      | col >= n && addend -> canSolve (row + 1) col carry
      | col == n && (not addend) -> return $ carry == 0
      | addend && assigned -> canSolve (row + 1) col (carry + y)
      | addend ->
          let unused = filter (\i -> i == 0 && isNonZero letter) [0 .. 9] \\ used
           in assignAny unused y solution
      | assigned && sumDigit == y -> canSolve 0 (col + 1) (carry `mod` 10)
      | sumDigit `elem` used -> return $ False
      | sumDigit == 0 && isNonZero letter -> return $ False
      | otherwise -> assign 0 (col + 1) (carry `mod` 10) y sumDigit solution
  where
    ord x = C.ord x - C.ord 'A'
    assignAny [] _ _ = return (False)
    assignAny (i : xs) y solution = do
      success <- assign (row + 1) col (carry + i) y i solution
      if success then return (success) else assignAny xs y solution
    assign r c cr y i solution = do
      UM.write solution y i
      success <- canSolve r c cr
      M.when (not success) (UM.write solution y (-1))
      return (success)

Error:

• Couldn't match type ‘s’
                     with ‘primitive-0.7.3.0:Control.Monad.Primitive.PrimState m0’
      Expected: MVector
                  (primitive-0.7.3.0:Control.Monad.Primitive.PrimState (ST s)) Int
        Actual: MVector
                  (primitive-0.7.3.0:Control.Monad.Primitive.PrimState m0) Int
      ‘s’ is a rigid type variable bound by
        the type signature for:
          canSolve :: forall s. Int -> Int -> Int -> M s Bool
        at src/Alphametics.hs:31:1-41
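
The ambiguous `m0` in this message most likely comes from the pure `let used = filter (\i -> 0 <= UM.read solution i) ...` binding (an assumption, since the reported location is only the type signature). `UM.read` returns a monadic action rather than an `Int`, so when it is compared inside a pure `filter`, GHC cannot tell which monad it should run in. A minimal sketch of reading the slots inside the monad instead; `usedDigits` and `demoUsed` are made-up names, and the environment is reduced to just the vector:

```haskell
import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import qualified Data.Vector.Unboxed.Mutable as UM

-- Simplified environment: only the mutable vector of assignments.
type M s = ReaderT (UM.MVector s Int) (ST s)

-- Collect the digits currently in use (slots holding a value >= 0).
usedDigits :: M s [Int]
usedDigits = do
  solution <- ask
  -- mapM sequences the reads; lift runs each one in the underlying ST monad.
  slots <- mapM (lift . UM.read solution) [0 .. UM.length solution - 1]
  return (filter (>= 0) slots)

-- Two letters assigned, everything else left at -1.
demoUsed :: [Int]
demoUsed = runST $ do
  solution <- UM.replicate 26 (-1)
  UM.write solution 3 7
  UM.write solution 18 9
  runReaderT usedDigits solution
```

The draft has a few other differences from the working version further down. Its `used` list holds slot indices rather than digits, the `unused` filter is inverted, `assign` is handed the value read (`y`) instead of the slot (`x`), and the carry passed to the next column uses `mod` where `div` is needed.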

Solution

  • OP here, figured it out myself. This code, and an alternative implementation using State monad, are available here. I’ve done some benchmarking, and surprisingly, the immutable version using State appears to be faster than the mutable code below.

    {-# LANGUAGE NamedFieldPuns #-}
    {-# LANGUAGE RecordWildCards #-}
    
    module Alphametics (solve) where
    
    import Control.Monad as M
    import Control.Monad.Reader (ReaderT)
    import qualified Control.Monad.Reader as R
    import Control.Monad.ST (ST)
    import qualified Data.Char as C
    import Data.List ((\\))
    import qualified Data.List as L
    import Data.Set (Set)
    import qualified Data.Set as S
    import qualified Data.Vector as V
    import qualified Data.Vector.Unboxed as U
    import qualified Data.Vector.Unboxed as VU
    import Data.Vector.Unboxed.Mutable (MVector)
    import qualified Data.Vector.Unboxed.Mutable as UM
    
    solve :: String -> Maybe [(Char, Int)]
    solve puzzle
      -- validate equation, "ABC + DEF == GH" is invalid,
      -- sum isn't wide enough
      | any (\x -> length x > (length . head) res) eqn = Nothing
      | otherwise = findSoln $ VU.create $ do
          let nonZeroLetters = S.fromList nz
          -- process in reverse
          let equation = (V.fromList . map (U.fromList . reverse)) eqn
          let result = (U.fromList . reverse . head) res
          solution <- UM.replicate 26 (-1)
          _ <- R.runReaderT (canSolve 0 0 0) PuzzleState {..}
          return solution
      where
        xs = filter (all C.isAsciiUpper) $ words puzzle
        (eqn, res) = L.splitAt (length xs - 1) xs
        -- leading letters can't be zero
        nz = [head x | x <- xs, length x > 1]
        chr x = C.chr (C.ord 'A' + x)
        findSoln v = case [ (chr x, y)
                            | x <- [0 .. 25],
                              let y = v VU.! x,
                              y >= 0
                          ] of
          [] -> Nothing
          x -> Just x
    
    data PuzzleState s = PuzzleState
      { equation :: V.Vector (U.Vector Char),
        result :: U.Vector Char,
        nonZeroLetters :: Set Char,
        solution :: MVector s Int
      }
    
    type M s = ReaderT (PuzzleState s) (ST s)
    
    canSolve :: Int -> Int -> Int -> M s Bool
    canSolve row col carry = do
      PuzzleState {equation, result, nonZeroLetters, solution} <- R.ask
    
      let addend = row < V.length equation
      let word = if addend then equation V.! row else result
      let n = U.length word
    
      case () of
        _
          | col >= n && addend -> canSolve (row + 1) col carry
          | col == n && not addend -> return $ carry == 0
          | otherwise -> do
              let letter = word U.! col
              let x = ord letter
              i <- readM solution x
              let assigned = i >= 0
              let canBeZero = flip S.notMember nonZeroLetters
              let sumDigit = carry `mod` 10
              used <- M.mapM (readM solution) [0 .. 25]
              let unused =
                    filter
                      (\y -> y > 0 || canBeZero letter)
                      [0 .. 9]
                      \\ used
    
              case () of
                _
                  | addend && assigned -> canSolve (row + 1) col (carry + i)
                  | addend -> assignAny solution x unused
                  | assigned ->
                      if sumDigit == i
                        then canSolve 0 (col + 1) (carry `div` 10)
                        else return False
                  | sumDigit `elem` used -> return False
                  | sumDigit == 0 && (not . canBeZero) letter -> return False
                  | otherwise ->
                      assign
                        0
                        (col + 1)
                        (carry `div` 10)
                        solution
                        x
                        sumDigit
      where
        -- lift is needed because we're working in a ReaderT monad,
        -- whereas UM.read and UM.write work in the ST monad
        readM solution = R.lift . UM.read solution
        ord c = C.ord c - C.ord 'A'
        assignAny _ _ [] = return False
        assignAny solution ix (i : xs) = do
          success <- assign (row + 1) col (carry + i) solution ix i
          if success then return success else assignAny solution ix xs
        assign r c cr solution ix i = do
          UM.write solution ix i
          success <- canSolve r c cr
          M.unless success (UM.write solution ix (-1))
          return success