
F# CPS evaluation order


I'm trying to understand the order of evaluation when using continuation-passing style (CPS) in F#. Take this function, for example:

let rec CPSfunc n k c =
    if k = 0 then c 1
    else if k > 0 then CPSfunc n (k-1) (fun res -> c(2*res+k))

Running CPSfunc 4 3 id evaluates to 19, but when I evaluate it by hand, I get different results depending on whether I evaluate forward or backward first:

CPSfunc 4 3 (fun res -> res)
CPSfunc 4 2 (fun res -> fun 2*res+3 -> res)
CPSfunc 4 1 (fun res -> fun 2*res+2 -> fun 2*res+3 -> res)
// Evaluating backwards
fun res -> fun 2*res+2 -> fun 2*res+3 -> 1
fun res -> fun 2*res+2 -> 2*1+3
fun res -> 2*5+2
// Evaluating forward
fun 1 -> fun 2*res+2 -> fun 2*res+3 -> res
fun 2*1+2 -> fun 2*res+3 -> res
fun 2*4+3 -> res
4

How do I properly calculate the correct output?


Solution

  • To see that 19 is the correct result, I think it's easiest to start with k = 0 and increment. Each result is simply twice the previous result, plus k. (Note that n is not used.) So we have:

    k = 0 ->     1     =  1
    k = 1 -> 2 * 1 + 1 =  3
    k = 2 -> 2 * 3 + 2 =  8
    k = 3 -> 2 * 8 + 3 = 19
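
The recurrence in the table can also be written as an ordinary, non-CPS recursive function, which gives a reference value to compare the CPS version against. This is only a minimal sketch: the name direct is hypothetical and does not appear in the question, and the unused n parameter is dropped.

```fsharp
// Plain recursion: each result is twice the previous result, plus k.
// `direct` is a hypothetical helper, not part of the original code.
let rec direct k =
    if k = 0 then 1
    elif k > 0 then 2 * direct (k - 1) + k
    else failwith "k < 0"

let direct3 = direct 3   // 19, matching the last row of the table
```

Here the pending "2 * _ + k" work sits on the call stack; in the CPS version the same pending work is stored in the continuation instead.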
    

    Converting that simple logic into continuations gets complicated, though. Here's what the expansion looks like in F# for CPSfunc 4 3 id:

    // unexpanded initial call
    let c3 = (fun res -> res)
    CPSfunc 4 3 c3
    
    // expanded once, k = 3
    let c2 = (fun res -> c3 (2 * res + 3))
    CPSfunc 4 2 c2
    
    // expanded again, k = 2
    let c1 = (fun res -> c2 (2 * res + 2))
    CPSfunc 4 1 c1
    
    // expanded again, k = 1
    let c0 = (fun res -> c1 (2 * res + 1))
    CPSfunc 4 0 c0
    
    // full expansion, k = 0
    c0 1
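
To finish the calculation, the final call c0 1 can be unwound one continuation at a time. The sketch below restates the bindings from the expansion above (with an int annotation on c3 added so it stands alone) and shows each step as a comment.

```fsharp
// Continuations from the expansion, oldest first.
let c3 = fun (res: int) -> res           // the original continuation (id)
let c2 = fun res -> c3 (2 * res + 3)     // built when k = 3
let c1 = fun res -> c2 (2 * res + 2)     // built when k = 2
let c0 = fun res -> c1 (2 * res + 1)     // built when k = 1

// c0 1
// = c1 (2 * 1 + 1) = c1 3
// = c2 (2 * 3 + 2) = c2 8
// = c3 (2 * 8 + 3) = c3 19
// = 19
let result = c0 1
```

The continuation built last (c0, from k = 1) runs first, so the additions are applied in the order +1, +2, +3. Each continuation computes 2 * res + k on its argument and then hands that value to the older continuation; it does not take 2 * res + k as a parameter pattern, which is where the hand evaluation in the question goes wrong.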
    

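    P.S. As written in the question, the inner if has no else branch, so F# requires that branch to have type unit, and c is inferred as int -> unit rather than int -> int. To give c the desired int -> int signature, CPSfunc needs a final else branch, so that's what I assume you've actually done: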

    let rec CPSfunc n k c =
        if k = 0 then c 1
        elif k > 0 then CPSfunc n (k-1) (fun res -> c(2*res+k))
        else failwith "k < 0"
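
One way to check the evaluation order without expanding by hand is to print from inside each continuation. The following is a sketch only: cpsTraced is a hypothetical variant of the corrected definition with the unused n parameter dropped and a printfn added.

```fsharp
// Hypothetical traced variant of CPSfunc (not from the original post).
// Each continuation reports its step just before calling the older one.
let rec cpsTraced k c =
    if k = 0 then c 1
    elif k > 0 then
        cpsTraced (k - 1) (fun res ->
            let next = 2 * res + k
            printfn "k = %d: 2 * %d + %d = %d" k res k next
            c next)
    else failwith "k < 0"

let traced = cpsTraced 3 id
// prints:
// k = 1: 2 * 1 + 1 = 3
// k = 2: 2 * 3 + 2 = 8
// k = 3: 2 * 8 + 3 = 19
```

Nothing is printed while the recursion descends from k = 3 to k = 0; the lines appear only once c 1 is called in the base case, starting with the most recently built continuation.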