
Is it possible to use continuations to make foldRight tail recursive?

The following blog article shows how foldBack can be made tail recursive in F# using continuation-passing style.

In Scala this would mean that:

def foldBack[T,U](l: List[T], acc: U)(f: (T, U) => U): U = {
  l match {
    case x :: xs => f(x, foldBack(xs, acc)(f))
    case Nil => acc
  }
} 

can be made tail recursive by doing this:

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    l match {
      case x :: xs => loop(xs, (racc => k(f(x, racc))))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

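Unfortunately, I still get a stack overflow for long lists. loop itself is tail recursive and gets optimized, but I guess the stack accumulation has just moved into the continuation calls: once the final k(acc) runs, each continuation calls the previous one, adding one nested frame per list element.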

Why is this not a problem with F#? And is there any way to work around this with Scala?
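One possible workaround, sketched here under assumptions rather than taken from the post, is to store the pending continuations as plain data instead of nested closures, so that unwinding them becomes a tail-recursive loop too. The names DefunctionalizedFold, Cont, Done, Apply, foldDefunc and run are hypothetical and only serve the illustration:

```scala
object DefunctionalizedFold {
  // Hypothetical data representation of the closure "racc => k(f(x, racc))":
  // it only needs the captured element x and the previous continuation k.
  sealed trait Cont[+T]
  case object Done extends Cont[Nothing]                  // stands for "u => u"
  final case class Apply[T](x: T, next: Cont[T]) extends Cont[T]

  def foldDefunc[T, U](list: List[T], acc: U)(f: (T, U) => U): U = {
    // Applies the stored continuations one by one; a self tail call, so no stack growth.
    @annotation.tailrec
    def run(k: Cont[T], value: U): U = k match {
      case Apply(x, next) => run(next, f(x, value))
      case Done           => value
    }

    // Same shape as loop in foldCont, but it builds data instead of closures.
    @annotation.tailrec
    def loop(l: List[T], k: Cont[T]): U = l match {
      case x :: xs => loop(xs, Apply(x, k))
      case Nil     => run(k, acc)
    }

    loop(list, Done)
  }
}
```

Cont here has the same shape as a reversed List[T], so this amounts to reversing the list and then folding from the left.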

Edit: here is some code that shows the depth of the stack:

// Prints a label followed by the current depth of the call stack.
def showDepth(s: Any): Unit = {
  println(s.toString + ": " + (new Exception).getStackTrace.length)
}

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    showDepth("loop")
    l match {
      case x :: xs => loop(xs, (racc => { showDepth("k"); k(f(x, racc)) }))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

foldCont(List.fill(10)(1), 0)(_ + _)

This prints:

loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
k: 51
k: 52
k: 53
k: 54
k: 55
k: 56
k: 57
k: 58
k: 59
k: 60
res2: Int = 10
huynhjl asked Dec 18 '11 02:12



1 Answer

Jon, n.m., thank you for your answers. Based on your comments, I thought I'd give trampolining a try. A bit of research shows that Scala has library support for trampolines in scala.util.control.TailCalls. Here is what I came up with after a bit of fiddling around:

def foldContTC[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  import scala.util.control.TailCalls._
  @annotation.tailrec
  def loop(l: List[T], k: (U) => TailRec[U]): TailRec[U] = {
    l match {
      case x :: xs => loop(xs, (racc => tailcall(k(f(x, racc)))))
      case Nil => k(acc)
    }
  }
  loop(list, u => done(u)).result
} 
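As a quick sanity check, the following sketch repeats the definition so that it is self-contained and compares the result with the library foldRight on a small list, then runs it on a list of one million elements. The object name TrampolinedFoldDemo and the sample inputs are assumptions for illustration only:

```scala
import scala.util.control.TailCalls._

object TrampolinedFoldDemo {
  // Same definition as foldContTC in this answer, repeated to keep the sketch self-contained.
  def foldContTC[T, U](list: List[T], acc: U)(f: (T, U) => U): U = {
    @annotation.tailrec
    def loop(l: List[T], k: U => TailRec[U]): TailRec[U] = l match {
      // tailcall defers the call to k, so result can run it in a loop
      // instead of nesting one stack frame per element.
      case x :: xs => loop(xs, racc => tailcall(k(f(x, racc))))
      case Nil     => k(acc)
    }
    loop(list, u => done(u)).result
  }

  def main(args: Array[String]): Unit = {
    val small = List("a", "b", "c")
    // String concatenation is order sensitive, so this checks right-fold semantics.
    println(foldContTC(small, "")(_ + _) == small.foldRight("")(_ + _))

    // A list long enough that nested continuation calls would be a concern.
    val big = List.fill(1000000)(1)
    println(foldContTC(big, 0)(_ + _))
  }
}
```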

I was interested to see how this compares to the solution without the trampoline as well as the default foldLeft and foldRight implementations. Here is the benchmark code and some results:

val size = 1000
val list = List.fill(size)(1)
val warm = 10
val n = 1000
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldCont", warm, lots(n, foldCont(list, 0)(_ + _)))
bench("foldRight", warm, lots(n, list.foldRight(0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))

The timings are:

foldContTC: warming...
Elapsed: 0.094
foldCont: warming...
Elapsed: 0.060
foldRight: warming...
Elapsed: 0.160
foldLeft: warming...
Elapsed: 0.076
foldLeft.reverse: warming...
Elapsed: 0.155

Based on this, trampolining appears to perform reasonably well at this list size. I suspect that its overhead, on top of the boxing/unboxing cost, is relatively small.
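The bench and lots helpers used in the benchmark are not shown in the post. A minimal sketch of what they might look like follows. The signatures, the warm-up strategy and the choice of seconds as the unit are all assumptions, inferred only from how the helpers are called and from the printed output:

```scala
object BenchSketch {
  // Hypothetical helper: evaluates body n times and returns the last result.
  def lots[A](n: Int, body: => A): A = {
    require(n >= 1, "n must be at least 1")
    var last = body
    var i = 1
    while (i < n) { last = body; i += 1 }
    last
  }

  // Hypothetical helper: evaluates body `warm` times untimed, then times one more evaluation.
  def bench[A](name: String, warm: Int, body: => A): A = {
    println(name + ": warming...")
    var i = 0
    while (i < warm) { body; i += 1 }
    val start = System.nanoTime()
    val result = body
    val seconds = (System.nanoTime() - start) / 1e9 // assumed unit: seconds
    println(f"Elapsed: $seconds%.3f")
    result
  }
}
```

With helpers along these lines, a call such as bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _))) would time n consecutive folds after warm untimed rounds. Single timed runs like this are noisy on the JVM, so the numbers are best read as rough indications.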

Edit: as suggested by Jon's comments, here are the timings on 1M items, which confirm that performance degrades with larger lists. I also found out that List does not override the library foldLeft implementation, so I also timed the following hand-written foldLeft2:

def foldLeft2[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  list match {
    case x :: xs => foldLeft2(xs, f(x, acc))(f)
    case Nil => acc
  }
} 

val size = 1000000
val list = List.fill(size)(1)
val warm = 10
val n = 2
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft2", warm, lots(n, foldLeft2(list, 0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))
bench("foldLeft2.reverse", warm, lots(n, foldLeft2(list.reverse, 0)(_ + _)))

yields:

foldContTC: warming...
Elapsed: 0.801
foldLeft: warming...
Elapsed: 0.156
foldLeft2: warming...
Elapsed: 0.054
foldLeft.reverse: warming...
Elapsed: 0.808
foldLeft2.reverse: warming...
Elapsed: 0.221

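So among the variants that process the list from the right (foldContTC, foldLeft.reverse and foldLeft2.reverse), foldLeft2.reverse is the winner...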

huynhjl answered Sep 24 '22 00:09