
Get current continuation in Scala


Haskell has a function for getting the current continuation:

getCC = callCC (\c -> let x = c x in return x)

How can I write a similar function in Scala?

For example, cats.data.ContT provides a callCC function. How could we use it?

I've tried many approaches, but I can't get any of them to work.


Solution

  • Let's implement this getCC ("get current continuation") in Scala.

    For starters, let's understand what is actually happening in Haskell. When we look at the docs and sources we'll find:

    newtype ContT (r :: k) (m :: k -> Type) a -- or
    newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
    
    class Monad m => MonadCont (m :: Type -> Type) where
      callCC :: ((a -> m b) -> m a) -> m a
    
    instance forall k (r :: k) (m :: (k -> Type)) . MonadCont (ContT r m) where
      callCC = ContT.callCC
    
    -- callCC specialized to MonadCont (ContT r m):
      callCC :: ((a -> ContT r m b) -> ContT r m a) -> ContT r m a
    

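    So what we see here: ContT r m a is a wrapper around a function (a -> m r) -> m r, i.e. a computation that receives "the rest of the program" as a callback a -> m r. callCC passes that callback to its argument as an escape function a -> ContT r m b, so calling the escape function abandons the current block and continues with whatever follows the callCC call.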
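
    To make these signatures more concrete, here is a minimal sketch in plain Scala 3. The Cont class and the checked function are hypothetical and written only for illustration: this is not the cats implementation, it drops the effect type m, and it is not stack-safe.

```scala
// Hypothetical minimal continuation monad (not cats.data.ContT).
// Mirrors Haskell's `(a -> r) -> r`, without the effect type `m`.
final case class Cont[R, A](run: (A => R) => R) {
  def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
    Cont[R, B](k => run(a => f(a).run(k)))

  def map[B](f: A => B): Cont[R, B] =
    Cont[R, B](k => run(a => k(f(a))))
}

object Cont {
  def pure[R, A](a: A): Cont[R, A] = Cont[R, A](k => k(a))

  // `f` receives an escape function. Calling it ignores the local
  // continuation and resumes with `k`, the continuation of the callCC call.
  def callCC[R, A, B](f: (A => Cont[R, B]) => Cont[R, A]): Cont[R, A] =
    Cont[R, A](k => f(a => Cont[R, B](_ => k(a))).run(k))
}

// Illustrative use: exit early with -1 for negative input, otherwise double it.
def checked(n: Int): Cont[String, Int] =
  Cont.callCC[String, Int, Int] { exit =>
    for {
      _       <- if (n < 0) exit(-1) else Cont.pure[String, Int](n)
      doubled <- Cont.pure[String, Int](n * 2)
    } yield doubled
  }
```

    Running checked(n).run(_.toString) supplies the final continuation; for a negative n the doubling step is skipped because exit jumps straight to it.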

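    OK, so let's move to the getCC part. In the Haskell definition, let x = c x is a recursive binding that only works because Haskell evaluates it lazily. Scala evaluates eagerly, so the same self-reference has to be declared as a lazy val with a deferred body, which is the role ContT.later plays in the cats version.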
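
    Before bringing in cats, the idea can be sketched on a hypothetical, stdlib-only Cont type in Scala 3. Cont, its getCC, and countToThree are illustrative names rather than any library's API, and the sketch is not stack-safe, so it only suits short loops.

```scala
// Hypothetical minimal continuation monad (not cats.data.ContT).
final case class Cont[R, A](run: (A => R) => R) {
  def flatMap[B](f: A => Cont[R, B]): Cont[R, B] =
    Cont[R, B](k => run(a => f(a).run(k)))

  def map[B](f: A => B): Cont[R, B] =
    Cont[R, B](k => run(a => k(f(a))))
}

object Cont {
  def pure[R, A](a: A): Cont[R, A] = Cont[R, A](k => k(a))

  def callCC[R, A, B](f: (A => Cont[R, B]) => Cont[R, A]): Cont[R, A] =
    Cont[R, A](k => f(a => Cont[R, B](_ => k(a))).run(k))

  // Returns a "label": running the label re-enters the code after getCC.
  def getCC[R, A]: Cont[R, Cont[R, A]] =
    callCC[R, Cont[R, A], A] { c =>
      // x mentions itself only inside the lambda, so nothing is evaluated
      // until the label is actually run. This stands in for Haskell's
      // lazy `let x = c x`.
      lazy val x: Cont[R, A] = Cont[R, A](k => c(x).run(k))
      pure[R, Cont[R, A]](x)
    }
}

// Illustrative loop: jump back to the label until the counter reaches 3.
def countToThree(): Int = {
  var counter = 0
  val program: Cont[Unit, Unit] =
    for {
      label <- Cont.getCC[Unit, Unit]
      _     <- Cont.pure[Unit, Unit](counter += 1)
      _     <- if (counter < 3) label else Cont.pure[Unit, Unit](())
    } yield ()
  program.run(_ => ())
  counter
}
```

    Each time label is run, everything after the getCC line executes again, which is why the counter is incremented once per jump.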

    With cats, you can define getCC as an extension method on the ContT companion object. For Scala 3 it would be:

    extension (contT: ContT.type)
      def getCC[M[_]: Defer, R, A]: ContT[M, R, ContT[M, R, A]] =
        ContT.callCC[M, R, ContT[M, R, A], A] { (c: (ContT[M, R, A] => ContT[M, R, A])) =>
          lazy val x: ContT[M, R, A] = ContT.later(c(x).run)
    
          ContT.pure[M, R, ContT[M, R, A]](x)
        }
    

    and, as you already verified, it works:

    import cats.data.ContT
    import cats.Defer
    
    extension (contT: ContT.type)
      def getCC[M[_]: Defer, R, A]: ContT[M, R, ContT[M, R, A]] =
        ContT.callCC[M, R, ContT[M, R, A], A] { (c: (ContT[M, R, A] => ContT[M, R, A])) =>
          lazy val x: ContT[M, R, A] = ContT.later(c(x).run)
    
          ContT.pure[M, R, ContT[M, R, A]](x)
        }
    
    // Testing
    
    import cats.effect.IO
    import cats.effect.unsafe.implicits.global
    import scala.concurrent.duration.{Duration, MILLISECONDS}
    import java.time.LocalDateTime
    
    val start = LocalDateTime.now()
    
    // Prints "Tick" once per second, jumping back to gotoLabel
    // until the current second is a multiple of 20.
    def loop =
      for {
        gotoLabel <- ContT.getCC[IO, Unit, Unit]
        _         <- ContT.liftF(IO.sleep(Duration(1000, MILLISECONDS)))
        _         <- ContT.liftF(IO.println("Tick"))
        now       <- ContT.liftF(IO(LocalDateTime.now()))
        _         <- if (now.getSecond % 20 != 0) gotoLabel
                     else ContT.liftF(IO.unit)
      } yield ()
      
    loop.run(_ => IO.println("Done")).unsafeRunSync()
    

    I'm not completely sure it always works. I'd suggest writing plenty of tests, precisely because we have to handle the eager vs. lazy evaluation problem ourselves, but it should explain the idea well enough.

    Notice how much effort went into manually resolving all kinds of type parameters. (And into understanding what is going on in general: it's pretty counterintuitive, and I have to learn this style anew every time I meet it.) In Haskell, type inference works differently and lets you skip these annotations (what other issues that brings I'll leave to Haskellers to explain). But this style is definitely opaque to read, hard to debug, and difficult to maintain. I can see a use case for it in internal logic that only a few people have to tinker with (and only if it actually brings some value!), but I'd definitely recommend against using it widely in a codebase or in business logic.