 

Convert normal recursion to tail recursion


I was wondering whether there is a general method to convert a "normal" recursion, with foo(...) + foo(...) as the last call, into tail recursion.

For example (Scala):

    def pascal(c: Int, r: Int): Int = {
      if (c == 0 || c == r) 1
      else pascal(c - 1, r - 1) + pascal(c, r - 1)
    }

A general solution for functional languages to convert a recursive function into a tail-call equivalent:

A simple way is to wrap the non-tail-recursive function in the Trampoline monad.

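    // Assumes scalaz 7.x: Free.Trampoline is the type, Trampoline the helper object
    import scalaz.{Free, Trampoline}

    def pascalM(c: Int, r: Int): Free.Trampoline[Int] = {
      if (c == 0 || c == r) Trampoline.done(1)
      else for {
        // suspend the recursive calls to pascalM instead of running them now
        a <- Trampoline.suspend(pascalM(c - 1, r - 1))
        b <- Trampoline.suspend(pascalM(c, r - 1))
      } yield a + b
    }

    // column 5 of row 10 (the column must not exceed the row)
    val result = pascalM(5, 10).run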

So pascalM no longer recurses on the call stack: each recursive call is suspended inside the Trampoline instead of being executed immediately. A Trampoline value is a nested data structure describing the computation that still needs to be done. Finally, run is a tail-recursive interpreter that walks through this tree-like structure and, once it reaches the base cases, returns the value.
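For comparison, the same trampolining idea can be written with the Scala standard library's scala.util.control.TailCalls, which avoids the scalaz dependency. This is a minimal sketch; the object and method names (PascalTailCalls, pascalT) are illustrative. Like the scalaz version, it removes the stack-depth problem but still performs exponentially many calls.

```scala
import scala.util.control.TailCalls._

object PascalTailCalls {
  // done(x) wraps an already-computed value; tailcall(...) suspends a call
  // so that .result can run the whole computation in a loop instead of
  // growing the JVM call stack.
  def pascalT(c: Int, r: Int): TailRec[Long] =
    if (c == 0 || c == r) done(1L)
    else for {
      a <- tailcall(pascalT(c - 1, r - 1))
      b <- tailcall(pascalT(c, r - 1))
    } yield a + b
}
```

Usage: PascalTailCalls.pascalT(5, 10).result evaluates column 5 of row 10.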

A paper by Rúnar Bjarnason on the subject of trampolines: Stackless Scala With Free Monads

asked Sep 22 '13 at 14:09 by DennisVDB





2 Answers

In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is tail recursion modulo cons (TRMC), where a simple recursive function of this form:

    def recur[A](...): List[A] = {
      ...
      x :: recur(...)
    }

which is not tail recursive, is transformed into

    def recur[A](...): List[A] = {
      def consRecur(..., consA: A): List[A] = {
        consA :: ...
        ...
        consRecur(..., ...)
      }
      ...
      consRecur(..., ...)
    }

Alexlv's example is a variant of this.

This is such a well-known situation that some compilers (I know of Prolog and Scheme examples, but scalac does not do this) can detect simple cases and perform this optimisation automatically.
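In Scala, the usual manual version of this transformation uses an accumulator plus a final reverse, because an immutable List cell cannot have its tail filled in afterwards the way a TRMC-optimising compiler does it. A minimal sketch, with hypothetical function names doubleAll and doubleAllTR:

```scala
import scala.annotation.tailrec

object DoubleAll {
  // Naive form: the cons (::) runs after the recursive call returns,
  // so the call is not in tail position and the stack grows with the list.
  def doubleAll(xs: List[Int]): List[Int] = xs match {
    case Nil       => Nil
    case x :: rest => (x * 2) :: doubleAll(rest)
  }

  // Accumulator form: the cons happens before recursing, so loop is
  // tail-recursive. Prepending builds the result backwards, hence reverse.
  def doubleAllTR(xs: List[Int]): List[Int] = {
    @tailrec
    def loop(rest: List[Int], acc: List[Int]): List[Int] = rest match {
      case Nil        => acc.reverse
      case x :: tail  => loop(tail, (x * 2) :: acc)
    }
    loop(xs, Nil)
  }
}
```

The extra reverse costs one more pass over the list, which is the price of not being able to mutate the cons cells.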

Problems that combine the results of multiple recursive calls have no such simple solution. TRMC optimisation is useless here, as you are simply moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent, but it requires finding an entirely different approach to solving the problem.

As it happens, your example is in some ways similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly recursive solution can be replaced by one that loops forward from the 0th number.

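    import scala.annotation.tailrec

    // Naive, doubly recursive version
    def fib(n: Long): Long = n match {
      case 0 | 1 => n
      case _ => fib(n - 2) + fib(n - 1)
    }

    // Replacement: loop forward from the 0th number (loop is tail-recursive).
    // next is a plain by-value Long; a by-name parameter would build up
    // nested thunks and re-evaluate them on every step.
    def fib(n: Long): Long = {
      @tailrec
      def loop(current: Long, next: Long, iteration: Long): Long = {
        if (n == iteration)
          current
        else
          loop(next, current + next, iteration + 1)
      }
      loop(0, 1, 0)
    }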

For the Fibonacci sequence, this linear loop is the standard efficient approach (a streams-based solution is just a different expression of it, one that can cache results for subsequent calls). You can also solve your problem by looping forward from the top of the triangle and calculating each row in sequence; the difference is that you need to keep the entire previous row. So while this resembles fib, it differs considerably in the specifics: it needs memory proportional to the row length rather than two accumulators, although it still does far less work than your original doubly recursive solution.
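The row-by-row idea described above can be sketched as a tail-recursive loop that carries the previous row as its accumulator. This is a minimal sketch assuming Int arguments and Long results; PascalRows is a hypothetical name. It builds whole rows for simplicity, although only the first c + 1 entries are actually needed.

```scala
import scala.annotation.tailrec

object PascalRows {
  def pascal(c: Int, r: Int): Long = {
    require(0 <= c && c <= r)

    // row holds row i of the triangle; stop once we reach row r
    @tailrec
    def loop(row: Vector[Long], i: Int): Long =
      if (i == r) row(c)
      else {
        // interior entries of the next row: sums of adjacent pairs
        val sums = row.zip(row.tail).map { case (a, b) => a + b }
        loop((1L +: sums) :+ 1L, i + 1)
      }

    loop(Vector(1L), 0)
  }
}
```

This does O(r^2) additions in total, so PascalRows.pascal(30, 60) returns immediately.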

Here's an approach for your Pascal's triangle example that can calculate pascal(30, 60) efficiently:

    import scala.annotation.tailrec

    def pascal(column: Long, row: Long): Long = {
      type Point = (Long, Long)
      type Points = List[Point]
      type Triangle = Map[Point, Long]
      def above(p: Point) = (p._1, p._2 - 1)
      def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)
      @tailrec
      def find(ps: Points, t: Triangle): Long = ps match {
        // Found the ultimate goal
        case (p :: Nil) if t contains p => t(p)
        // Found an intermediate point: pop the stack and carry on
        case (p :: rest) if t contains p => find(rest, t)
        // Hit a triangle edge, add it to the triangle
        case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c, r) -> 1L))
        // Triangle contains (c - 1, r - 1)...
        case (p :: _) if t contains aboveLeft(p) =>
          if (t contains above(p))
            // And it contains (c, r - 1)! Add to the triangle
            find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
          else
            // Does not contain (c, r - 1). So find that
            find(above(p) :: ps, t)
        // If we get here, we don't have (c - 1, r - 1). Find that.
        case (p :: _) => find(aboveLeft(p) :: ps, t)
      }
      require(column >= 0 && row >= 0 && column <= row)
      (column, row) match {
        case (c, r) if (c == 0) || (c == r) => 1
        case p => find(List(p), Map())
      }
    }

It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.
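To make the "continuations" option concrete: converting to continuation-passing style and then representing the continuations as plain data (defunctionalization) gives a single tail-recursive loop for any recursion of this shape. This is a minimal sketch, assuming Int arguments; Cont, Done, NeedRight, AddLeft and PascalCps are hypothetical names. It is stack-safe, but it still does the same exponential amount of work as the original doubly recursive definition.

```scala
import scala.annotation.tailrec

object PascalCps {
  // Defunctionalized continuation: data recording what is left to do
  // once the current sub-result is known.
  sealed trait Cont
  case object Done extends Cont
  // Waiting for the left sub-result; afterwards evaluate pascal(c, r).
  final case class NeedRight(c: Int, r: Int, next: Cont) extends Cont
  // Waiting for the right sub-result; afterwards add it to `left`.
  final case class AddLeft(left: Long, next: Cont) extends Cont

  def pascal(column: Int, row: Int): Long = {
    // state is a subproblem to evaluate (Left) or a value to return (Right)
    @tailrec
    def loop(state: Either[(Int, Int), Long], k: Cont): Long = state match {
      case Left((c, r)) =>
        if (c == 0 || c == r) loop(Right(1L), k)
        else loop(Left((c - 1, r - 1)), NeedRight(c, r - 1, k))
      case Right(v) =>
        k match {
          case Done                  => v
          case NeedRight(c, r, next) => loop(Left((c, r)), AddLeft(v, next))
          case AddLeft(left, next)   => loop(Right(left + v), next)
        }
    }
    loop(Left((column, row)), Done)
  }
}
```

The chain of Cont values plays the role of the call stack, but it lives on the heap, so deep inputs such as PascalCps.pascal(1, 100000) do not overflow the JVM stack.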

You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.

answered Sep 18 '22 at 17:09 by itsbruce



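I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail recursion. Try computing pascal(30, 60), for example: every result is ultimately a sum of the 1s returned by the base cases, so that call alone makes roughly 1.2 × 10^17 base-case calls. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.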

Instead, consider using a Stream or memoization:

    // Stream is deprecated since Scala 2.13 in favour of LazyList.
    // Row r, column c is pascal(r)(c), e.g. pascal(60)(30).
    val pascal: Stream[Stream[Long]] =
      (Stream(1L)
        #:: (Stream from 1 map { i =>
          // compute row i
          (1L
            #:: (pascal(i-1) // take the previous row
                 sliding 2 // and add adjacent values pairwise
                 collect { case Stream(a, b) => a + b }).toStream
            ++ Stream(1L))
        }))
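The memoization alternative can be sketched with a mutable map as a cache. This is a minimal sketch under stated assumptions: PascalMemo and its cache are hypothetical names, and the cache is not thread-safe. The function is still not tail-recursive, but its recursion depth is bounded by the row number and each entry is computed only once.

```scala
import scala.collection.mutable

object PascalMemo {
  // (column, row) -> value; filled in as entries are first computed
  private val cache = mutable.Map.empty[(Int, Int), Long]

  def pascal(c: Int, r: Int): Long =
    if (c == 0 || c == r) 1L
    else cache.get((c, r)) match {
      case Some(v) => v
      case None =>
        // compute from the two entries above, then remember the result
        val v = pascal(c - 1, r - 1) + pascal(c, r - 1)
        cache((c, r)) = v
        v
    }
}
```

With the cache in place, PascalMemo.pascal(30, 60) needs only a few hundred additions instead of the exponential number made by the plain recursion.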
answered Sep 18 '22 at 17:09 by Aaron Novstrup
