Tags: haskell, continuations, continuation-passing

Have I traced the continuations right?


I am learning about continuations. I thought a good exercise would be to trace the standard factorial function written in continuation-passing style, and I want to check whether my understanding is correct.

Code:

fact :: Integer -> (Integer -> Integer) -> Integer
fact 0 k = k 1
fact n k = fact $ n-1 (\v -> k (n*v))

-- call: fact 5 id
-- answer: 120

This is how I traced the code (I think being able to trace it is fundamental to understanding how continuations work):

fact 5 id --> 
fact 4 (\v -> id (5*v)) --> 
fact 3 (\v -> (\v -> id (5*v)) (4*v)) --> 
fact 2 (\v -> (\v -> (\v -> id (5*v)) (4*v)) (3*v)) --> 
fact 1 (\v -> (\v -> (\v -> (\v -> id (5*v)) (4*v)) (3*v)) (2*v)) --> 
fact 0 (\v -> (\v -> (\v -> (\v -> (\v -> id (5*v)) (4*v)) (3*v)) (2*v)) (1*v)) --> 
(\v -> (\v -> (\v -> (\v -> (\v -> id (5*v)) (4*v)) (3*v)) (2*v)) (1*v)) 1

Is this how it is supposed to trace out, or do I have the fundamentals wrong?

PS: I realize all the v's are a little confusing, but am I right in assuming that the inner v shadows the outer v?
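To see the shadowing in isolation, here is a small sketch; the names shadowed and renamed are made up for illustration. Compiled with -Wall, GHC should warn that the inner binding of v shadows the outer one. Both expressions evaluate to 20, because inside the inner lambda v refers to the inner argument (1 + 1), not to the outer 1.

```haskell
{-# OPTIONS_GHC -Wall #-}

-- The inner lambda's parameter v hides the outer v: in the body v * 10,
-- v is the inner argument (v + 1) = 2, not the outer value 1.
shadowed :: Integer
shadowed = (\v -> (\v -> v * 10) (v + 1)) 1

-- The same expression with distinct names, which is easier to trace.
renamed :: Integer
renamed = (\v0 -> (\v1 -> v1 * 10) (v0 + 1)) 1

main :: IO ()
main = do
  print shadowed -- 20
  print renamed  -- 20
```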


Solution

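  • Yes, you have the fundamentals right. The only thing that is slightly off is the definition of fact itself: because function application binds more tightly than -, the line fact $ n-1 (\v -> k (n*v)) parses as fact (n - (1 (\v -> k (n*v)))), which does not type-check. The argument has to be written as (n-1). The inner v does indeed shadow the outer v, which GHC will tell you if you turn on all warnings with -Wall (it includes -Wname-shadowing). To make the trace easier to follow, you might want to give each variable a unique name. Here's how I would write it, with the corrected definition: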

    fact :: Integer -> (Integer -> Integer) -> Integer
    fact 0 k = k 1
    fact n k = fact (n-1) (\v -> k (n*v))  -- (n-1) needs its own parentheses
    
    -- Each print below is one step of tracing fact 5 id; every line prints 120.
    main :: IO ()
    main = do
        print $ fact 5 id
        print $ fact 4 (\v0 -> id (5*v0))
        print $ fact 4 (\v0 -> 5*v0)  -- id (5*v0) is just 5*v0
        print $ fact 3 (\v1 -> (\v0 -> 5*v0) (4*v1))
        print $ fact 2 (\v2 -> (\v1 -> (\v0 -> 5*v0) (4*v1)) (3*v2))
        print $ fact 1 (\v3 -> (\v2 -> (\v1 -> (\v0 -> 5*v0) (4*v1)) (3*v2)) (2*v3))
        print $ fact 0 (\v4 -> (\v3 -> (\v2 -> (\v1 -> (\v0 -> 5*v0) (4*v1)) (3*v2)) (2*v3)) (1*v4))
        -- Base case: fact 0 k = k 1, so the accumulated continuation is applied to 1.
        print $ (\v4 -> (\v3 -> (\v2 -> (\v1 -> (\v0 -> 5*v0) (4*v1)) (3*v2)) (2*v3)) (1*v4)) 1
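Since \v -> k (n*v) is just k composed with multiplication by n, the same function can be written with (.), which makes the shape of the trace easier to see: the continuation that reaches the base case for fact 5 id is id . (5*) . (4*) . (3*) . (2*) . (1*), applied to 1. A minimal sketch, where factC is a hypothetical name for this variant:

```haskell
-- Same algorithm as fact, but the continuation \v -> k (n*v)
-- is written as the composition k . (n *): multiply by n first,
-- then pass the result to the previous continuation k.
factC :: Integer -> (Integer -> Integer) -> Integer
factC 0 k = k 1
factC n k = factC (n - 1) (k . (n *))

main :: IO ()
main = do
  print (factC 5 id) -- 120
  -- The continuation built up by factC 5 id, written out by hand:
  print ((id . (5 *) . (4 *) . (3 *) . (2 *) . (1 *)) 1) -- 120
  -- Sanity check against the direct definition of factorial.
  print (all (\n -> factC n id == product [1 .. n]) [0 .. 10]) -- True
```

Like the original, this never reaches the 0 case for negative input, so it does not terminate there.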