haskell

Could someone explain to me, like I am 5, what these Iterator and Yield monad types and functions mean?


Below is the code. I understand Applicative, Functor, Traversable and Monad to a certain extent. The Iterator and Yield types, and the yield function, are what I struggle to understand the most. For example, what are i, o and r in Iterator, and i, o, r and a in Yield? What does traverseY do exactly, and what do the signatures mean? How the Cont monad is applied here also confuses me a lot. Thank you for reading; any input would be appreciated.

import Control.Monad.Cont
import Control.Monad.Random
import System.Random.Shuffle


data Iterator i o r =
    Result r
  | Susp o (i -> Iterator i o r)
--------------------------------------------------------------------------------------------
newtype Yield i o r a = Yield { unY :: Cont (Iterator i o r) a }

instance Functor (Yield i o r) where
  fmap f (Yield m) = Yield (fmap f m)

instance Applicative (Yield i o r) where
  pure a = Yield (pure a)
  (Yield mf) <*> (Yield ma) = Yield (mf <*> ma)

instance Monad (Yield i o r) where
  return a = Yield (return a)
  (Yield m) >>= k = Yield (m >>= \a -> unY (k a))

instance MonadCont (Yield i o r) where
  callCC c = Yield (callCC (\k -> unY (c (\a -> Yield (k a)))))

runYield :: Yield i o r r -> Iterator i o r
runYield (Yield m) = runCont m Result

yield :: o -> Yield i o r i
yield o = callCC (\k -> Yield (cont (\_ -> Susp o (\i -> runYield (k i)))))

-------------------------------------------------------------------------------------------
data Tree a = Empty | Node a (Tree a) (Tree a)

traverseY :: Tree a -> Yield (Tree a) (a,Tree a,Tree a) r ()
traverseY Empty = return ()
traverseY (Node a t1 t2) = do t <- yield (a, t1, t2); traverseY t

Solution

  • An Iterator i o r represents a process that repeatedly outputs values of type o while consuming one value of type i after each o, eventually breaking the iteration by returning an r instead of an o. E.g.

    accum :: Iterator Int Int Int
    accum = go 0
      where go n | n >= 100 = Result n
                 | otherwise = Susp n (\i -> go (n + i))
    
    -- potential usage
    main :: IO ()
    main = go accum -- iterator returns modified copies instead of mutating, so prepare to replace the iterator through iteration/recursion
      where go (Result r) = putStrLn $ "Final total: " ++ show r -- check whether iterator has values/extract values by pattern matching; a finished iterator can return extra data of type r if it likes
            go (Susp o i) = do
              putStrLn $ "Running total: " ++ show o
              putStr $ "Add: "
              n <- readLn
              go (i n) -- bidirectional communication! get "incremented" iterator by feeding in an input value (you could write no-input iterators by setting i = ())
    

    is an iterator that keeps track of some accumulator n. Each iteration, it outputs the current accumulator and then waits for an input which it adds to the accumulator. Once the accumulator reaches 100, it stops accepting input and returns the final value. Note that, since everything is immutable, in order to keep state, the Iterator has to return a new version of itself every time it changes state. Whoever is iterating "through" accum in turn has to use the returned Iterator instead of accum itself. In Python, you could write accum as:

    def accum(): # calling accum() instead creates a new object that mutates itself through iteration
      sum = 0
      while sum < 100: sum = sum + (yield sum)
      return sum
    
    # same usage
    def main():
      gen = accum()
      try: # Python doesn't have a shorthand syntax for using these iterators, but this should be legible enough
        o = next(gen)
        while True:
          print(f"Running total: {o}")
          o = gen.send(int(input("Add: ")))
      except StopIteration as e:
        print(f"Final total: {e.value}")
    

    Yield i o r r is being used as a "builder" for Iterator i o r. You can write an analogue of Yield's interface for Iterator directly:

    instance Functor (Iterator i o) where
      fmap f (Result x) = Result (f x)
      fmap f (Susp o i) = Susp o (fmap f . i)
    instance Applicative (Iterator i o) where
      pure = Result
      liftA2 = liftM2
    instance Monad (Iterator i o) where
      Result r >>= f = f r
      Susp o i >>= f = Susp o ((>>= f) . i)
    
    -- yieldI x is the iterator that sends x to the caller and receives a value in return
    -- in Python: def yieldI(x): return yield x
    yieldI :: o -> Iterator i o i
    yieldI x = Susp x Result
    

    E.g. the generator in your DFS example is

    data Tree a = Empty | Node a (Tree a) (Tree a)
      deriving (Functor, Foldable, Traversable) -- DeriveTraversable; needed for the traverse/foldMap uses below
    dfsI :: Tree a -> Iterator b a (Tree b) -- yield elements of the tree (o = a) *and also* receive new elements to replace them (i = b), building a new tree (r = Tree b)
    -- dfs = traverse yieldI
    dfsI Empty = Result Empty
    dfsI (Node l m r) = do
      l' <- yieldI l
      m' <- dfsI m
      r' <- dfsI r
      return (Node l' m' r')
    

    The issue with using Iterator i o r directly here is that it is inefficient. (Remember that Haskell is lazily evaluated to understand the following.) If you "concatenate" many iterators, like ((x >>= f) >>= g) >>= h, you run into trouble when you try to evaluate the result. Say x evaluates to Susp o i. Then evaluating the bigger expression first makes three function calls (into the >>=s), evaluates x to Susp o i, and then creates new suspended function calls to produce Susp o (\a -> ((i a >>= f) >>= g) >>= h). When you iterate through this iterator (i.e. extract the lambda and call it with some argument), each iteration must walk through all the >>=s that are still hanging on to the Iterator. Yikes. (Stated perhaps in more familiar terms, we've implemented iterator concatenation as a "wrapper" around another iterator, which gets bad when you have wrappers on wrappers on wrappers...)
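
    To see the left-nesting concretely, here is a small sketch of the kind of chain described above (leftNested and its foldl formulation are my own illustration, not part of the original code):

    -- hypothetical stress case: a left-nested chain of (>>=)s on Iterator
    leftNested :: Int -> Iterator () Int Int
    leftNested n = foldl (>>=) (Result 0) (replicate n step)
      where step k = yieldI k >> Result (k + 1) -- yield k, ignore the () fed back, continue with k + 1
    -- each input fed in must still pass through all the (>>=)s composed onto the continuation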

    Using Cont is a "standard trick" for avoiding this. The idea is that, instead of handling an iterator x directly, we handle its bind function (wrapped in Cont and Yield newtypes) x' = \f -> x >>= f. Note that converting a monadic computation to its bind function is reversible, since x' return = x >>= return = x.

    newtype Yield i o r a = Yield { unYield :: (a -> Iterator i o r) -> Iterator i o r }
    instance Functor (Yield i o r) where
      fmap f (Yield r) = Yield (\k -> r $ k . f)
    instance Applicative (Yield i o r) where
      pure x = Yield (\k -> k x)
      liftA2 = liftM2
    instance Monad (Yield i o r) where
      Yield r >>= f = Yield (\k -> r $ \x -> unYield (f x) k)
    -- newtype Yield i o r a = Yield { unYield :: (a -> Iterator i o r) -> Iterator i o r }
    -- compare         (>>=) :: Iterator i o a -> (a -> Iterator i o r) -> Iterator i o r
    -- giving
    liftIY :: Iterator i o a -> Yield i o r a
    liftIY x = Yield (x >>=)
    -- and the law x >>= return = x inspires
    runYield :: Yield i o r r -> Iterator i o r
    runYield (Yield r) = r return
    -- e.g.
    yield :: o -> Yield i o r i
    -- yield = liftIY . yieldI
    yield x = Yield (\k -> Susp x (\i -> k i)) -- note this is the same as yours after you drill through newtype baggage
    

    Instead of having ((x >>= f) >>= g) >>= h, using Yield will make terms like (((x' >>= liftIY . f) >>= liftIY . g) >>= liftIY . h) return appear at runtime. This evaluates to just x' >>= _someFunction, so all the wrappers from before have collapsed into just one, hopefully leading to efficiency. (This collapsed wrapper will "replace itself" as you iterate through, going through the behaviors specified by f, g, and h in turn. This is encoded in Yield's >>=.)

    -- magical free efficiency by replacing Iterator -> Yield, yieldI -> yield
    dfs :: Tree a -> Yield b a r (Tree b)
    dfs = traverse yield
    -- when using Yields as builders, you will treat Yield i o _ r like Iterator i o r
    -- the middle parameter should be polymorphic for builders like dfs
    -- while the final consumer (particularly the "standard consumer" runYield) fixes it to something (runYield sets it to the final return type)
    -- (this behavior is a loosely typed reflection of the Codensity monad)
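
    As a small usage sketch (relabel is my own name; it assumes the Traversable-derived Tree above), here is the Yield-based dfs being driven at its use site, with the consumer feeding consecutive integers back in to rebuild the tree:

    -- hypothetical consumer: replace every element with its visit index
    relabel :: Tree a -> Tree Int
    relabel tr = go 0 (runYield (dfs tr))
      where go _ (Result t) = t                -- iteration finished: the rebuilt tree
            go n (Susp _ i) = go (n + 1) (i n) -- answer each yielded element with its index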
    

    At the final use site, Yields have to be made into Iterators before they can be iterated. Your dfsDirect (a rough sketch of it follows this list):

    1. Uses dfs = traverse yield to build a Yield without incurring the wrath of "accidentally quadratic left associativity" (technical term :))
    2. Builds an Iterator out of that Yield using runYield. This iterator goes through tr and yields/replaces its elements.
    3. Iterates that iterator via loop, which...
    4. "Converses" with dfs via the line loop (Susp o i) = loop (i [o]): when it receives o it sends [o] back into the iterator, which puts it in the Tree it's building.
    5. Upon exhausting the iterator, receives a new Tree where every element is replaced with a singleton list (loop (Result r) = _).
    6. Concatenates all the lists together via the Foldable Tree instance.
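
    Since dfsDirect itself is not quoted above, here is a rough reconstruction of what those six steps describe (the original may differ in details; it assumes the derived Foldable Tree instance):

    -- hypothetical reconstruction of dfsDirect from steps 1-6
    dfsDirect :: Tree a -> [a]
    dfsDirect tr = loop (runYield (dfs tr))
      where loop (Susp o i) = loop (i [o]) -- answer each yielded element with a singleton list
            loop (Result r) = concat r     -- flatten via the Foldable Tree instance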

    This is a pretty dumb way to do things, since the order doesn't come from dfs but from the Foldable Tree instance used in the last step. The Iterator is just used as a glorified fmap function. If the Foldable instance were different (e.g. BFS, or even just inorder vs preorder), but you kept dfs as a preorder DFS (so it would no longer be written traverse yield), dfsDirect would not output in the actual order defined by dfs! You could write a function that properly turns an Iterator into a list.

    -- note the usage of () as an input type for "plain" iterators
    -- since we cannot know what to pass in otherwise
    -- I am also being careful about strictness:
    --   forcing (iToList i) to the (,) constructor does not force the iterator i at all,
    --   the elements of the list are properly lazily evaluated,
    --   and evaluating the r component forces the whole iteration at once
    iToList :: Iterator () o r -> ([o], r)
    iToList i = let (xs, r) = go i in (xs, r)
      where go (Result r) = ([], r)
            go (Susp x i) = let (xs, r) = go (i ()) in (x : xs, r)
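
    For instance (dfsToList is my own name, and it again assumes the Traversable-derived Tree), this recovers the elements in exactly the order dfs yields them:

    -- hypothetical: list a tree's elements in the order dfs actually yields them
    dfsToList :: Tree a -> [a]
    dfsToList = fst . iToList . runYield . dfs
    -- the input type gets fixed to () here, so the rebuilt Tree () is discarded by fst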
    

    traverseY is also a bit strange. If it receives a Node (either as the initial value or as an iterator input), it yields back the fields of that Node; on Empty it returns. It doesn't actually "traverse" its input: you can send it off into a completely new tree just by passing that tree as input. I assume the idea is that when you iterate over it, you send back one of the Trees it previously returned, so it iterates over a path to a leaf. IMO it would be nicer to write

    data Direction = L | R
    path :: Tree a -> Yield Direction a r () -- outputs root and goes in direction as told until it runs out of tree
    path Empty = pure ()
    path (Node x l r) = do
      d <- yield x
      case d of
        L -> path l
        R -> path r
    -- potential use
    elemBST :: Ord a => a -> Tree a -> Bool
    elemBST x xs = go (runYield $ path xs)
      where go (Result ()) = False -- iteration ended without success
            go (Susp y i) = case compare x y of
              LT -> go (i L) -- go left
              EQ -> True     -- end
              GT -> go (i R) -- go right
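
    A quick check with a hand-built search tree (exampleBST is my own; it assumes the usual BST ordering invariant):

    exampleBST :: Tree Int
    exampleBST = Node 5 (Node 2 Empty Empty) (Node 8 Empty Empty)
    -- elemBST 8 exampleBST == True  (go right at 5, hit 8)
    -- elemBST 3 exampleBST == False (go left at 5, right at 2, run out of tree)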