
Refactor impure recursion with state monad?


I've been dissecting this one-liner solution for Advent of Code 2022, day 14, and came across an elegant, impure recursive function:

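def s(x,y):
    # h (depth limit) and m (set of blocked cells) are globals
    if y > h: return True          # fell below the limit: into the abyss
    if (x, y) in m: return False   # cell is already rock or sand
    # try down, down-left, down-right and return the first True; if none
    # succeeds, the grain rests here and m.add returns None (falsy)
    return next((r for d in (0,-1,1) if (r:=s(x+d,y+1))), None) or m.add((x,y))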

Full solution on Godbolt

One way to make this pure is to pass the set m into s explicitly and return it alongside the result (i.e. s :: int -> int -> set -> (bool, set)).
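A minimal sketch of that explicit-passing version in Haskell, assuming (as the Python appears to) that `h` is the depth limit and that sand enters at (500, 0). The names `sExplicit`, `Blocked` and `countResting` are hypothetical. Each recursive call has to receive the set returned by the previous one, and that hand-threading is exactly what the State monad hides:

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

-- Cells that are blocked, by rock or by sand that has come to rest.
type Blocked = Set (Int, Int)

-- Pure version with explicit state passing: the set is an extra argument
-- and is returned alongside the Bool, like s :: int -> int -> set -> (bool, set).
-- h is assumed to be the depth limit, as in the Python.
sExplicit :: Int -> Int -> Int -> Blocked -> (Bool, Blocked)
sExplicit h x y m
  | y > h               = (True, m)   -- fell below the limit
  | Set.member (x, y) m = (False, m)  -- cell already blocked
  | otherwise           = tryMoves [0, -1, 1] m
  where
    -- Try down, down-left, down-right, handing each call the set
    -- returned by the previous one.
    tryMoves [] m' = (False, Set.insert (x, y) m')  -- nowhere to go: rest here
    tryMoves (d : ds) m' =
      case sExplicit h (x + d) (y + 1) m' of
        (True, m'')  -> (True, m'')   -- stop at the first True, like next(...)
        (False, m'') -> tryMoves ds m''

-- How many grains rest before one falls into the abyss (source at 500,0).
countResting :: Int -> Blocked -> Int
countResting h rocks =
  Set.size (snd (sExplicit h 500 0 rocks)) - Set.size rocks
```

Note that `m'` and `m''` must be named and passed along by hand; accidentally reusing a stale set would still compile but silently lose sand, which is the kind of bug the monadic version rules out.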

However, I've also read that the Reader/Writer/State monads save you from passing the extra parameter and unpacking the tuple result, and I'm interested in porting this recursion to Haskell.
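To see where the extra parameter goes, here is a hand-rolled State monad. It is an illustrative sketch, not mtl's actual definition (in mtl/transformers, `State s` is a synonym for `StateT s Identity`), and `visit`/`visitTwice` are hypothetical helpers. The only place the state is passed along is `>>=`:

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

-- A hand-rolled State monad, for illustration only. A stateful computation
-- is just a function from the current state to a result and a new state.
newtype State st a = State { runState :: st -> (a, st) }

instance Functor (State st) where
  fmap f (State g) = State $ \s0 ->
    let (a, s1) = g s0 in (f a, s1)

instance Applicative (State st) where
  pure a = State $ \s0 -> (a, s0)
  State gf <*> State ga = State $ \s0 ->
    let (f, s1) = gf s0
        (a, s2) = ga s1
    in (f a, s2)

instance Monad (State st) where
  -- All of the plumbing lives here: run the first step, then hand the
  -- state it returned to the next step.
  State g >>= k = State $ \s0 ->
    let (a, s1) = g s0
    in runState (k a) s1

get :: State st st
get = State $ \s0 -> (s0, s0)

put :: st -> State st ()
put s1 = State $ \_ -> ((), s1)

-- Hypothetical example: mark a cell as visited and report whether it was new.
visit :: (Int, Int) -> State (Set (Int, Int)) Bool
visit cell = do
  seen <- get
  put (Set.insert cell seen)
  pure (not (Set.member cell seen))

-- No set appears in the arguments or results here; >>= threads it.
visitTwice :: State (Set (Int, Int)) (Bool, Bool)
visitTwice = do
  first <- visit (1, 1)
  second <- visit (1, 1)
  pure (first, second)
```

`runState visitTwice Set.empty` evaluates to `((True, False), fromList [(1,1)])`: the second `visit` sees the set updated by the first, even though neither call mentions the set.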

I found a Haskell solution on Reddit that looks like it may use the same recursion (along with two others that don't):

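{-# LANGUAGE BlockArguments, FlexibleContexts #-}

import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Array.MArray (MArray, readArray, writeArray)
import Data.Bool (bool)
import Data.Ix (Ix)
import Data.Maybe (fromMaybe)
import Data.STRef (modifySTRef', newSTRef, readSTRef, writeSTRef)

fill :: (MArray a Bool (ST s), Ix i, Num i, Show i) => a (i, i) Bool -> i -> ST s (Int, Int)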
fill blocks maxY = do
    counterAtMaxY <- newSTRef Nothing
    counter <- newSTRef 0
    let fill' (x, y) = readArray blocks (x, y) >>= flip bool (pure ()) do
            when (y == maxY) $ readSTRef counterAtMaxY >>= maybe
                (readSTRef counter >>= writeSTRef counterAtMaxY . Just) (const $ pure ())
            when (y <= maxY) $ fill' (x, y + 1) >> fill' (x - 1, y + 1) >> fill' (x + 1, y + 1)
            writeArray blocks (x, y) True >> modifySTRef' counter (+ 1)
    fill' (500, 0)
    counterAtMaxY <- readSTRef counterAtMaxY
    counter <- readSTRef counter
    pure (fromMaybe counter counterAtMaxY, counter)

Full solution on Godbolt

Could someone confirm that this is indeed a port of the Python solution? If so, could someone baby me through how the recursion happens?

I'm still not Haskell-literate. I can sort of make out that the line fill' (500, 0) means m >>= \_ -> fill' (500, 0), which I read as "discard the current state and create a new monad independently" (something gets preserved, but I'm confused about what). I also don't understand monad transformers at all.
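On the `fill' (500, 0)` confusion: in do-notation, a line without `<-` is joined to the next with `>>`, and `>>=`/`>>` throw away the previous action's *result*, not its state. In the Reddit code that line is `fill' (500, 0) >> ...`, and the following `readSTRef` calls see the same `counter`, `counterAtMaxY` and `blocks` that `fill'` wrote to. The sketch below (hypothetical names) writes one small State computation with and without do-notation, then repeats it in ST. Neither needs monad transformers:

```haskell
import Control.Monad.ST
import Control.Monad.State
import Data.STRef

-- do-notation version: each line runs in the state left by the line before.
counterDo :: State Int Int
counterDo = do
  modify (+ 1)   -- result is (), which is ignored...
  modify (* 10)  -- ...but this step still sees the updated state
  get

-- The same computation desugared. \_ -> throws away the () result,
-- while the Int state keeps flowing from one step to the next.
counterDesugared :: State Int Int
counterDesugared =
  modify (+ 1) >>= \_ ->
  modify (* 10) >>= \_ ->
  get

-- The ST equivalent: the state is a mutable reference, and later
-- actions see earlier writes in the same way.
counterST :: Int -> Int
counterST start = runST $ do
  ref <- newSTRef start
  modifySTRef' ref (+ 1)
  modifySTRef' ref (* 10)
  readSTRef ref
```

Both `runState counterDo 2` and `runState counterDesugared 2` give `(30, 30)`, and `counterST 2` gives `30`: the `()` results are discarded, but (2 + 1) * 10 still reaches `get`.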

The Haskell solution also solves part 2 at the same time, so maybe someone could factor that out, so there's no confusion between the pair of Cartesian coordinates and the pair of Ints holding the two answers.
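For part 1 on its own, a closer ST analogue of the Python keeps the mutable set in an `STRef` and stops at the first `True`, just as `next(...)` does. This is a sketch, not the Reddit author's code: the Reddit solution uses a mutable array, always explores all three moves, and records the counter when it first reaches `maxY`. `fallST` and `part1` are hypothetical names, and `h` and the (500, 0) source are assumptions carried over from the Python:

```haskell
import Control.Monad.ST
import Data.STRef
import Data.Set (Set)
import qualified Data.Set as Set

-- Part 1 only, mirroring the Python: the mutable set m becomes an STRef
-- holding a Set, and the result is True once a grain falls below row h.
fallST :: Int -> STRef st (Set (Int, Int)) -> Int -> Int -> ST st Bool
fallST h ref x y
  | y > h = pure True
  | otherwise = do
      m <- readSTRef ref
      if Set.member (x, y) m
        then pure False
        else tryMoves [0, -1, 1]
  where
    -- Down, down-left, down-right; stop at the first True, otherwise rest here.
    tryMoves [] = modifySTRef' ref (Set.insert (x, y)) >> pure False
    tryMoves (d : ds) = do
      r <- fallST h ref (x + d) (y + 1)
      if r then pure True else tryMoves ds

-- Run the recursion from the sand source and count the grains that came to rest.
part1 :: Int -> Set (Int, Int) -> Int
part1 h rocks = runST $ do
  ref <- newSTRef rocks
  _ <- fallST h ref 500 0
  final <- readSTRef ref
  pure (Set.size final - Set.size rocks)
```

Here the only number coming out is the part 1 count, so there is no pair to confuse with a coordinate.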


Solution

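  • Below is a fairly close translation of your Python code to Haskell; comments in the code point out where it differs from the Python.

    module Main where
    
    import Control.Monad.State
    import Data.Set (Set)
    import qualified Data.Set as Set
    
    -- The Python's mutable set m becomes the state of a State monad.
    type M = State (Set (Int, Int))
    
    -- h is an explicit argument here instead of a global.
    s :: Int -> Int -> Int -> M Bool
    s h x y =
      if y > h then pure True
      else do
        m <- get  -- fetch the current set from the state
        if Set.member (x, y) m then
          pure False
        else
          -- Try down, down-left, down-right in order. If none returns True,
          -- add (x, y) and return False; the Python instead returns the
          -- falsy None produced by m.add.
          orM ([s h (x+d) (y+1) | d <- [0, -1, 1]] ++ [add (x, y) *> pure False])
    
    -- Short-circuiting "or" over monadic Bools, standing in for
    -- next(...) or ...: it stops running actions after the first True.
    orM :: Monad m => [m Bool] -> m Bool
    orM [] = pure False
    orM (x : xs) = do
      b <- x
      if b then pure True
      else orM xs
    
    -- Equivalent of m.add((x, y)), as a state update.
    add :: (Int, Int) -> M ()
    add (x, y) = modify (Set.insert (x, y))
    
    -- Example from https://adventofcode.com/2022/day/14
    
    m0 :: Set (Int, Int)
    m0 = vline 498 4 6 <> hline 498 496 6 <> hline 503 502 4 <> vline 502 4 9 <> hline 494 502 9
    
    -- Rock segments as sets of cells; endpoints may come in either order.
    vline, hline :: Int -> Int -> Int -> Set (Int, Int)
    vline x y1 y2 | y1 > y2 = vline x y2 y1
    vline x y1 y2 = Set.fromList [(x, y) | y <- [y1 .. y2]]
    
    hline x1 x2 y | x1 > x2 = hline x2 x1 y
    hline x1 x2 y = Set.fromList [(x, y) | x <- [x1 .. x2]]
    
    h0 :: Int
    h0 = 9
    
    main :: IO ()
    main =
      -- execState runs s from the source (500, 0) and returns the final set;
      -- the cells it gained are the grains that came to rest.
      print (Set.size (execState (s h0 500 0) m0) - Set.size m0)
      -- Output: 24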
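Because a `State (Set (Int, Int)) Bool` action is, underneath, a function `Set (Int, Int) -> (Bool, Set (Int, Int))`, `runState` turns the monadic `s` back into exactly the explicit-passing signature from the question, while `evalState` and `execState` keep only one half of the pair. In this sketch the recursion is restated with `gets` instead of `get` followed by `Set.member`; `sExplicit`, `fallsOut` and `settled` are hypothetical names:

```haskell
import Control.Monad.State
import Data.Set (Set)
import qualified Data.Set as Set

type M = State (Set (Int, Int))

-- Same recursion as s above, with the membership test done via gets.
s :: Int -> Int -> Int -> M Bool
s h x y
  | y > h = pure True
  | otherwise = do
      blocked <- gets (Set.member (x, y))
      if blocked then pure False else tryMoves [0, -1, 1]
  where
    -- Down, down-left, down-right; stop at the first True.
    tryMoves [] = modify (Set.insert (x, y)) >> pure False
    tryMoves (d : ds) = do
      r <- s h (x + d) (y + 1)
      if r then pure True else tryMoves ds

-- runState unwraps the State action into an explicit-passing function:
-- int -> int -> set -> (bool, set), as in the question.
sExplicit :: Int -> Int -> Int -> Set (Int, Int) -> (Bool, Set (Int, Int))
sExplicit h x y = runState (s h x y)

-- evalState keeps only the Bool: did a grain reach the abyss?
fallsOut :: Int -> Set (Int, Int) -> Bool
fallsOut h = evalState (s h 500 0)

-- execState keeps only the final set (rocks plus resting sand).
settled :: Int -> Set (Int, Int) -> Set (Int, Int)
settled h = execState (s h 500 0)
```

The `main` above uses the `execState` form, since only the final set is needed to count the sand.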