
Continuation Passing Style in Haskell with Binary Tree


I'm just learning CPS, and I'm trying to convert the following program to this style:

mirror Void = Void 
mirror (Node x left right) = Node x (mirror right) (mirror left)

From what I understand, I have to pass the base-case value to the continuation, and in the recursive case build up the result inside lambdas. For addition it would be:

-- assuming Peano naturals: data Nat = Zero | Succ Nat
cpsadd n Zero k = k n
cpsadd n (Succ m) k = cpsadd n m (\v -> k (Succ v))
-- where k is the continuation, i.e., the stack.

Another example, with a list:

mult [] k = k 1
mult (x:xs) k = mult xs (\v -> k (x*v))
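
For reference, functions like these are run by supplying an initial continuation, usually `id`. A minimal sketch, assuming a hypothetical Peano type `Nat` (constructors `Zero` and `Succ`) in place of the `succ(m)` pattern, plus a hypothetical helper `toNat` for building test values; neither appears in the original:

```haskell
-- Hypothetical Peano naturals, so that the successor case can be pattern-matched.
data Nat = Zero | Succ Nat deriving (Show, Eq)

-- CPS addition: recurse on the second argument, and instead of wrapping
-- Succ around the recursive result, push that work into the continuation.
cpsadd :: Nat -> Nat -> (Nat -> r) -> r
cpsadd n Zero     k = k n
cpsadd n (Succ m) k = cpsadd n m (\v -> k (Succ v))

-- CPS product of a list: the pending multiplications live in the continuation.
mult :: Num a => [a] -> (a -> r) -> r
mult []     k = k 1
mult (x:xs) k = mult xs (\v -> k (x * v))

-- Hypothetical helper: convert a non-negative Int to a Nat.
toNat :: Int -> Nat
toNat 0 = Zero
toNat n = Succ (toNat (n - 1))
```

For example, `mult [2, 3, 4] id` evaluates to `24`, and passing `show` instead of `id` yields the string `"24"`.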

With that in mind, my first idea was:

mirror Void k = k Void
mirror (Node x l r) k = Node x (\v -> k r v) (\w -> k v r)

But I immediately realized that I was building the tree without passing it to the continuation k. So my second idea was:

mirror Void k = k Void
mirror (Node x l r) k = mirror Node x (\v -> k r v l)

Now I do pass the continuation, but when I trace it by hand I never reach the base case, so it doesn't work either. I'm also confused by having to call the recursive function twice and swap the subtrees to build the mirror.

Any ideas? Thanks!


Solution

  • One basic transformation you need to perform often to get to CPS is turning

    f x y = g (h x) y -- non-CPS
    

    into

    f' x y k = h' x (\r -> g' r y k) -- CPS
    

    That is, whenever you need to call a function in the middle of an expression, you instead call the CPS version of that function and pass it, as its continuation, a lambda that finishes the expression (and eventually hands the result to k). So let's start with your definition of mirror and work towards a CPS implementation.
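
    As a concrete instance of this rule, here is a sketch with hypothetical choices h x = x * 2 and g a b = a + b (the original keeps h and g abstract). The CPS versions h' and g' pass their results to a continuation instead of returning them, and f' chains them following the template above:

```haskell
-- Hypothetical concrete choices for the schematic h and g.
h :: Int -> Int
h x = x * 2

g :: Int -> Int -> Int
g a b = a + b

-- Direct style: h is called in the middle of the expression.
f :: Int -> Int -> Int
f x y = g (h x) y

-- CPS versions: each takes an extra continuation k and passes
-- its result to k instead of returning it.
h' :: Int -> (Int -> ans) -> ans
h' x k = k (h x)

g' :: Int -> Int -> (Int -> ans) -> ans
g' a b k = k (g a b)

-- Call h' first; its continuation finishes the expression by calling g',
-- handing g' the outer continuation k.
f' :: Int -> Int -> (Int -> ans) -> ans
f' x y k = h' x (\r -> g' r y k)
```

    Running f' with id as the final continuation gives the same answer as f: both f' 3 4 id and f 3 4 evaluate to 10.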

    mirror Void = Void 
    mirror (Node x left right) = Node x (mirror right) (mirror left)
    

    I'll write [e] to denote that the expression e still needs to be converted to CPS form, and just e if it has already been transformed. First, let's add the k argument and wrap the right-hand sides in brackets to indicate that they need to be transformed:

    mirror Void k = [Void]
    mirror (Node x left right) k = [Node x (mirror right) (mirror left)]
    

    Transforming the Void case is easy: you did that already.

    mirror Void k = k Void
    mirror (Node x left right) k = [Node x (mirror right) (mirror left)]
    

    Now we need to address the first recursive call, mirror right. We make that call immediately and give it a continuation r, a local function whose body isn't fully converted yet:

    mirror Void k = k Void
    mirror (Node x left right) k = mirror right r
      where r right' = [Node x right' (mirror left)]
    

    Now the body of r has a call to mirror left that needs to be lifted out:

    mirror Void k = k Void
    mirror (Node x left right) k = mirror right r
      where r right' = mirror left l
              where l left' = [Node x right' left']
    

    Now the body of l has no recursive calls left and is exactly the value you wanted to pass to k in the first place, so the final transformation is easy: just call k.

    mirror Void k = k Void
    mirror (Node x left right) k = mirror right r
      where r right' = mirror left l
              where l left' = k $ Node x right' left'
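
    To try the final version, here is a self-contained sketch. It assumes the tree type is data Tree a = Void | Node a (Tree a) (Tree a), since the question does not show its declaration, and renames the CPS function to mirrorK (a hypothetical name) so it can sit next to the direct-style mirror:

```haskell
-- Assumed tree type; the question does not show its declaration.
data Tree a = Void | Node a (Tree a) (Tree a)
  deriving (Show, Eq)

-- Direct-style mirror, as in the question.
mirror :: Tree a -> Tree a
mirror Void = Void
mirror (Node x left right) = Node x (mirror right) (mirror left)

-- CPS mirror from the derivation: mirror the right subtree first,
-- then the left one, then hand the rebuilt node to k.
mirrorK :: Tree a -> (Tree a -> r) -> r
mirrorK Void k = k Void
mirrorK (Node x left right) k = mirrorK right r
  where r right' = mirrorK left l
          where l left' = k $ Node x right' left'
```

    Passing id as the initial continuation recovers the ordinary result, so mirrorK t id == mirror t for any tree t.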
    

    If you like, you can write that with lambdas instead of where clauses, but the principle is the same.
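
    For comparison, here is a sketch of the same function written with lambdas, assuming data Tree a = Void | Node a (Tree a) (Tree a) and again using the hypothetical name mirrorK. The nesting makes the evaluation order explicit: mirror the right subtree, then the left one, then build the node and pass it to k:

```haskell
-- Assumed tree type; the question does not show its declaration.
data Tree a = Void | Node a (Tree a) (Tree a)
  deriving (Show, Eq)

-- Same CPS mirror, with both continuations written as nested lambdas.
mirrorK :: Tree a -> (Tree a -> r) -> r
mirrorK Void k = k Void
mirrorK (Node x left right) k =
  mirrorK right (\right' ->      -- right subtree mirrored first
    mirrorK left (\left' ->      -- then the left subtree
      k (Node x right' left')))  -- finally rebuild with the sides swapped
```

    Any function can serve as the initial continuation; for instance, mirrorK t show produces the printed form of the mirrored tree directly.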