
How does this continuation-passing style Clojure function generator work?

This is from the Joy of Clojure, 2nd Edition. http://www.manning.com/fogus2/

 (defn mk-cps [accept? kend kont] 
   (fn [n] 
     ((fn [n k] 
        (let [cont (fn [v] (k ((partial kont v) n)))] 
          (if (accept? n) 
            (k 1) 
            (recur (dec n) cont)))) 
      n kend))) 

Then to make a factorial:

(def fac (mk-cps zero? identity #(* %1 %2)))

My understanding:

  • mk-cps generates a function which takes in n, the fn [n]
  • the function inside, fn [n k], is initially called with n and kend
  • the continuation function cont [v] calls k on the result of partially applying kont to v and then applying that to n, so v is kont's first argument and n its second. Why would this be written using partial instead of simply (k (kont v n))?
  • if the accept? function passes, then finish the recursion, applying k to 1.
  • otherwise, the recur recurs back to fn [n k] with a decremented n, and with the continuation function.
  • all throughout, kont does not change.

Am I right that k isn't actually executed until the final (k 1)? So, (fac 3) is expanded first to (* 1 (* 2 3)) before being evaluated.

mparaz asked Jan 31 '14 10:01



1 Answer

I don't have the book, but I assume the motivating example is

(defn fact-n [n]
  (if (zero? n)
      1
      (* n (recur (dec n)))))

;=> CompilerException: Can only recur from tail position

So that last form has to be written (* n (fact-n (dec n))) instead, which is not tail-recursive. The problem is that something remains to be done after the recursive call returns, namely the multiplication by n.
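Written out in full, the version that does compile would look something like this. It is only a sketch of the assumed motivating example, not code taken from the book:

```clojure
;; Non-tail-recursive factorial: the multiplication by n is still pending
;; when the recursive call is made, so every call keeps its own stack frame.
(defn fact-n [n]
  (if (zero? n)
    1
    (* n (fact-n (dec n)))))

(fact-n 3)
;=> 6
```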

What continuation passing style does is turn this inside out. Instead of applying what remains of the current context/continuation after the recursive call returns, pass the context/continuation into the recursive call to apply when complete. Instead of implicitly storing continuations on the stack as call frames, we explicitly accumulate them via function composition.

In this case, we add an additional argument k to our factorial, a function that does what we would have done after the recursive call returns.

(defn fact-nk [n k]
  (if (zero? n)
      (k 1)
      (recur (dec n) (comp k (partial * n)))))

The first k in is the last one out. Ultimately here we just want to return the value calculated, so the first k in should be the identity function.
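To see that the first continuation really does run last, here is a small sketch that passes something other than identity. The string-building continuation is a made-up illustration, not something from the book:

```clojure
(defn fact-nk [n k]
  (if (zero? n)
    (k 1)
    (recur (dec n) (comp k (partial * n)))))

;; identity as the initial continuation returns the factorial unchanged.
(fact-nk 4 identity)
;=> 24

;; A hypothetical initial continuation: it is applied last, after all the
;; multiplications, so it receives the finished value 24.
(fact-nk 4 #(str "4! = " %))
;=> "4! = 24"
```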

Here's the base case:

(fact-nk 0 identity)
;== (identity 1)
;=> 1

Here's n = 3:

(fact-nk 3 identity)
;== (fact-nk 2 (comp identity (partial * 3)))
;== (fact-nk 1 (comp identity (partial * 3) (partial * 2)))
;== (fact-nk 0 (comp identity (partial * 3) (partial * 2) (partial * 1)))
;== ((comp identity (partial * 3) (partial * 2) (partial * 1)) 1)
;== ((comp identity (partial * 3) (partial * 2)) 1)
;== ((comp identity (partial * 3)) 2)
;== ((comp identity) 6)
;== (identity 6)
;=> 6
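The trace above can be checked by building the accumulated continuation by hand and applying it to 1. The name k3 is just an illustrative choice:

```clojure
;; The continuation that (fact-nk 3 identity) has accumulated by the time
;; n reaches 0, written out by hand.
(def k3 (comp identity (partial * 3) (partial * 2) (partial * 1)))

;; comp applies right to left: (* 1 1), then (* 2 1), then (* 3 2),
;; and finally identity.
(k3 1)
;=> 6
```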

Compare this to the non-tail-recursive version:

(fact-n 3)
;== (* 3 (fact-n 2))
;== (* 3 (* 2 (fact-n 1)))
;== (* 3 (* 2 (* 1 (fact-n 0))))
;== (* 3 (* 2 (* 1 1)))
;== (* 3 (* 2 1))
;== (* 3 2)
;=> 6

Now, to make this a bit more flexible, we could factor out the zero? and the * and make them parameters instead.

A first approach would be

(defn cps-anck [accept? n c k]
  (if (accept? n)
      (k 1)
      (recur accept?, (dec n), c, (comp k (partial c n)))))

But since accept? and c are not changing, we could lift them out and recur to an inner anonymous function instead. Clojure has a special form for this, loop.

(defn cps-anckl [accept? n c k]
  (loop [n n, k k]
    (if (accept? n)
        (k 1)
        (recur (dec n) (comp k (partial c n))))))
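For comparison, here is a sketch of the "inner anonymous function" variant next to the loop variant. The names cps-inner-fn and cps-loop are mine, chosen for illustration. The first has the same shape as the ((fn [n k] ...) n kend) form in the question's mk-cps:

```clojure
;; Inner anonymous function, called immediately; recur targets that fn,
;; so accept? and c do not have to be passed along on each iteration.
(defn cps-inner-fn [accept? n c k]
  ((fn [n k]
     (if (accept? n)
       (k 1)
       (recur (dec n) (comp k (partial c n)))))
   n k))

;; The same logic with loop, as in cps-anckl above.
(defn cps-loop [accept? n c k]
  (loop [n n, k k]
    (if (accept? n)
      (k 1)
      (recur (dec n) (comp k (partial c n))))))

(cps-inner-fn zero? 5 * identity)
;=> 120
(cps-loop zero? 5 * identity)
;=> 120
```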

And finally we might want to turn this into a function generator that pulls in n.

(defn gen-cps [accept? c k]
  (fn [n]
    (loop [n n, k k]
      (if (accept? n)
          (k 1)
          (recur (dec n) (comp k (partial c n)))))))

And that is how I would write mk-cps (note: last two arguments reversed).

(def factorial (gen-cps zero? * identity))
(factorial 5)
;=> 120

(def triangular-number (gen-cps #{1} + identity))    
(triangular-number 5)
;=> 15
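As a sanity check, here is a sketch that puts mk-cps from the question next to gen-cps. On the partial question: ((partial kont v) n) evaluates to the same thing as (kont v n), so as far as I can tell the partial is not required there. One difference to be aware of is that mk-cps calls (kont v n), with the running value first, while gen-cps calls (c n v), so the two are only guaranteed to agree for commutative operations such as * and +.

```clojure
;; mk-cps as quoted in the question.
(defn mk-cps [accept? kend kont]
  (fn [n]
    ((fn [n k]
       (let [cont (fn [v] (k ((partial kont v) n)))]
         (if (accept? n)
           (k 1)
           (recur (dec n) cont))))
     n kend)))

;; gen-cps as defined above.
(defn gen-cps [accept? c k]
  (fn [n]
    (loop [n n, k k]
      (if (accept? n)
        (k 1)
        (recur (dec n) (comp k (partial c n)))))))

;; Partial application followed by a call is the same as the direct call.
(= ((partial * 2) 3) (* 2 3))
;=> true

((mk-cps zero? identity *) 5)
;=> 120
((gen-cps zero? * identity) 5)
;=> 120
```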
A. Webb answered Oct 20 '22 02:10