
Is it possible to convert a HOAS function to continuation passing style?


Consider the following Haskell program:

-- A HOAS term, non-polymorphic for simplicity
data Term
  = Lam (Term -> Term)
  | App Term Term
  | Num Int

-- Doubles every constant in a term
fun0 :: Term -> Term
fun0 (Lam b)   = Lam (\ x -> fun0 (b x))
fun0 (App f x) = App (fun0 f) (fun0 x)
fun0 (Num i)   = Num (i * 2)

-- Same function, using a continuation-passing style
fun1 :: Term -> (Term -> a) -> a
fun1 (Lam b)   cont = undefined
fun1 (App f x) cont = fun1 f (\ f' -> fun1 x (\ x' -> cont (App f' x')))
fun1 (Num i)   cont = cont (Num (i * 2))

-- Sums all nums inside a term
summ :: Term -> Int
summ (Lam b)   = summ (b (Num 0))
summ (App f x) = summ f + summ x
summ (Num i)   = i

-- Example
main :: IO ()
main = do
  let term = Lam $ \ x -> Lam $ \ y -> App (App x (Num 1)) (App y (Num 2))
  print (summ term)                 -- prints 3
  print (summ (fun0 term))          -- prints 6
  print (fun1 term $ \ t -> summ t) -- a.hs: Prelude.undefined 

Here, Term is a (non-polymorphic) λ-term with numeric constants, and fun0 is a function that doubles all constants inside a term. Is it possible to rewrite fun0 in a continuation-passing style? In other words, is it possible to complete the undefined case of the fun1 function such that it behaves identically to fun0 and such that the last print outputs 6?
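For comparison with the answer below, here is a hypothetical way to fill the undefined case while keeping Term unchanged. A lambda body is a plain Term -> Term function, so under the binder the traversal of b x has to be run to completion, here with id as its continuation (which instantiates a to Term through polymorphic recursion, permitted by the explicit signature). Note that this is only CPS outside binders: the call fun1 (b x) id is an ordinary non-tail call. That limitation is what changing the data type avoids.

```haskell
-- HOAS term from the question, non-polymorphic.
data Term
  = Lam (Term -> Term)
  | App Term Term
  | Num Int

-- Direct-style doubling from the question, for comparison.
fun0 :: Term -> Term
fun0 (Lam b)   = Lam (\ x -> fun0 (b x))
fun0 (App f x) = App (fun0 f) (fun0 x)
fun0 (Num i)   = Num (i * 2)

-- CPS outside binders only: under a binder the body must be a plain Term,
-- so its traversal is run to completion with 'id' (a := Term here).
fun1 :: Term -> (Term -> a) -> a
fun1 (Lam b)   cont = cont (Lam (\ x -> fun1 (b x) id))
fun1 (App f x) cont = fun1 f (\ f' -> fun1 x (\ x' -> cont (App f' x')))
fun1 (Num i)   cont = cont (Num (i * 2))

-- Sums all constants, feeding Num 0 to each binder.
summ :: Term -> Int
summ (Lam b)   = summ (b (Num 0))
summ (App f x) = summ f + summ x
summ (Num i)   = i

term :: Term
term = Lam $ \ x -> Lam $ \ y -> App (App x (Num 1)) (App y (Num 2))

main :: IO ()
main = do
  print (summ (fun0 term))  -- 6
  print (fun1 term summ)    -- 6
```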


Solution

  • If you want to convert this function to CPS, you also need to convert the function inside the data type, so that the body of a lambda takes a continuation too:

    data Term' a
      = Lam' (Term' a -> (Term' a -> a) -> a)
      | App' (Term' a) (Term' a)
      | Num' Int
    

    Then you can write your fun1 accordingly:

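    -- Doubles every constant and passes the result to the continuation.
    fun1 :: Term' a -> (Term' a -> a) -> a
    -- Under a binder, the body is handed to fun1 so its constants are doubled too.
    fun1 (Lam' b)   cont = cont (Lam' (\ x k -> b x (\ body -> fun1 body k)))
    fun1 (App' f x) cont = fun1 f (\ f' -> fun1 x (\ x' -> cont (App' f' x')))
    fun1 (Num' i)   cont = cont (Num' (i * 2))

    Note that the Lam' case runs fun1 on the body through the body's own continuation. Merely re-wrapping b, as in Lam' (\ x cont' -> b x cont'), would leave the constants under the binders untouched, and the example below would print 3 instead of 6.

    And with the appropriate tweak to summ:

    summ' :: Term' Int -> Int
    summ' (Lam' b)   = b (Num' 0) summ'
    summ' (App' f x) = summ' f + summ' x
    summ' (Num' i)   = i

    As well as a CPS term:

    term' :: Term' a
    term' = Lam' $ \ x k -> k $ Lam' $ \ y k' -> k' $ App' (App' x (Num' 1)) (App' y (Num' 2))

    You can run the computation just fine, and it matches summ (fun0 term):

    > fun1 term' summ'
    6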
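To check that the Term'-based fun1 behaves identically to a direct-style doubling, as the question requires, one can compare it against a hypothetical direct-style double' on Term' (not part of the answer) on a couple of terms. For the two to agree, the Lam' case of fun1 has to run fun1 on the lambda body rather than only re-wrapping b; this sketch assumes that version. The second test term, term2', is also an illustrative assumption.

```haskell
-- CPS-style HOAS term from the answer: lambda bodies take a continuation.
data Term' a
  = Lam' (Term' a -> (Term' a -> a) -> a)
  | App' (Term' a) (Term' a)
  | Num' Int

-- CPS doubling; under a binder the body is passed through fun1 before reaching k.
fun1 :: Term' a -> (Term' a -> a) -> a
fun1 (Lam' b)   cont = cont (Lam' (\ x k -> b x (\ body -> fun1 body k)))
fun1 (App' f x) cont = fun1 f (\ f' -> fun1 x (\ x' -> cont (App' f' x')))
fun1 (Num' i)   cont = cont (Num' (i * 2))

-- Hypothetical direct-style doubling on Term', used only as a reference.
double' :: Term' a -> Term' a
double' (Lam' b)   = Lam' (\ x k -> b x (k . double'))
double' (App' f x) = App' (double' f) (double' x)
double' (Num' i)   = Num' (i * 2)

-- Sums all constants, feeding Num' 0 to each binder (as in the answer).
summ' :: Term' Int -> Int
summ' (Lam' b)   = b (Num' 0) summ'
summ' (App' f x) = summ' f + summ' x
summ' (Num' i)   = i

-- The answer's example term.
term' :: Term' a
term' = Lam' $ \ x k -> k $ Lam' $ \ y k' -> k' $ App' (App' x (Num' 1)) (App' y (Num' 2))

-- Hypothetical second term: a lambda whose body applies x to 5, applied to 7.
term2' :: Term' a
term2' = App' (Lam' (\ x k -> k (App' x (Num' 5)))) (Num' 7)

main :: IO ()
main = do
  print (fun1 term' summ', summ' (double' term'))    -- (6,6)
  print (fun1 term2' summ', summ' (double' term2'))  -- (24,24)
```

Both traversals hand the body of each Lam' to the same continuation-based machinery, so summ' cannot tell them apart; the only difference is whether the doubling itself is written with an explicit continuation.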