 

How exactly does tail recursion work?

The compiler is simply able to transform this

int fac_times (int n, int acc) {
    if (n == 0) return acc;
    else return fac_times(n - 1, acc * n);
}

into something like this:

int fac_times (int n, int acc) {
label:
    if (n == 0) return acc;
    acc *= n--;
    goto label;
}
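The same hand transformation works for any tail-recursive function. As an illustration (Euclid's gcd is my own example, not part of the answer), here is a tail-recursive version next to a goto version in the same style. The factorial case hides one detail: when the new arguments depend on each other, they must all be computed before the old parameters are overwritten.

```c
/* Tail-recursive Euclid's algorithm: the recursive call is the last action. */
unsigned gcd_rec(unsigned a, unsigned b) {
    if (b == 0) return a;
    return gcd_rec(b, a % b);
}

/* The same function after the "goto" transformation:
   overwrite the parameters with the new arguments, then jump back to the top. */
unsigned gcd_goto(unsigned a, unsigned b) {
top:
    if (b == 0) return a;
    {
        unsigned r = a % b;  /* compute the new second argument from the OLD a and b */
        a = b;               /* only now overwrite the parameters */
        b = r;
    }
    goto top;
}
```

Both functions return the same value for every input; the second one simply never grows the stack.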

You ask why "it doesn't require stack to remember its return address".

I would like to turn this around. It does use the stack to remember the return address. The trick is that the function in which the tail recursion occurs has its own return address on the stack, and when it jumps to the called function, the called function treats that address as its own return address.

Concretely, without tail call optimization:

f: ...
   CALL g
   RET
g:
   ...
   RET

In this case, when g is called, the stack will look like:

   SP ->  Return address of "g"
          Return address of "f"

On the other hand, with tail call optimization:

f: ...
   JUMP g
g:
   ...
   RET

In this case, when g is called, the stack will look like:

   SP ->  Return address of "f"

Clearly, when g returns, it will return to the location where f was called from.

EDIT: The example above uses the case where one function calls another function. The mechanism is identical when the function calls itself.
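The f/g case can be tried in C with two functions that tail-call each other. This pair is a hypothetical illustration, not code from the answer. With tail call optimization, each call can become a plain jump, so the stack does not grow with n. C does not guarantee this; it depends on the compiler and optimization level.

```c
#include <stdbool.h>

bool is_odd(unsigned n);  /* forward declaration so is_even can call it */

bool is_even(unsigned n) {
    if (n == 0) return true;
    return is_odd(n - 1);   /* tail call to another function (the "JUMP g" case) */
}

bool is_odd(unsigned n) {
    if (n == 0) return false;
    return is_even(n - 1);  /* tail call back to the first function */
}
```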


Tail recursion can usually be transformed into a loop by the compiler, especially when accumulators are used.

// tail recursion
int fac_times (int n, int acc = 1) {
    if (n == 0) return acc;
    else return fac_times(n - 1, acc * n);
}

would compile to something like

// accumulator
int fac_times (int n) {
    int acc = 1;
    while (n > 0) {
        acc *= n;
        n -= 1;
    }
    return acc;
}

There are two elements that must be present in a recursive function:

  1. The recursive call
  2. A place to keep track of the return values.

A "regular" recursive function keeps (2) in the stack frame.

The return value of a regular recursive function is composed of two kinds of values:

  • Return values of the other (recursive) calls
  • The result of the function's own computation

Let's look at your example:

int factorial (int n) {
    if (n == 0) return 1;
    else return n * factorial(n - 1);
}

The frame f(5) "stores" the result of its own computation (5) and the value of f(4), for example. If I call factorial(5), just before the stack frames begin to collapse, I have:

 [Stack_f(5): 5 * [Stack_f(4): 4 * [Stack_f(3): 3 * [Stack_f(2): 2 * [Stack_f(1): 1 * [Stack_f(0): 1]]]]]]

Notice that each stack frame stores, besides the values I mentioned, the whole scope of the function (its parameters and local variables). So the memory usage of a recursive function f is O(x), where x is the number of recursive calls I have to make. So if I needed 1 KB of RAM to calculate factorial(1) or factorial(2), I would need ~100 KB to calculate factorial(100), and so on.
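To see the O(x) growth directly, the regular factorial can be instrumented to count how many of its frames are alive at once. The counters `depth` and `max_depth` are hypothetical additions for this sketch only; they are not part of the answer.

```c
/* Hypothetical instrumentation: tracks how many calls are on the stack at once. */
static int depth = 0;      /* calls currently active */
static int max_depth = 0;  /* deepest point reached so far */

unsigned long long factorial_counted(int n) {
    depth++;                               /* a new frame has been pushed */
    if (depth > max_depth) max_depth = depth;

    unsigned long long result;
    if (n == 0)
        result = 1;
    else
        result = n * factorial_counted(n - 1); /* the multiply happens AFTER the call returns */

    depth--;                               /* this frame is about to be popped */
    return result;
}
```

Calling `factorial_counted(5)` reaches a depth of 6 (frames for n = 5 down to 0), and in general the depth is n + 1, i.e. linear in the number of calls.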

A tail-recursive function puts (2) in its arguments.

In tail recursion, I pass the result of the partial calculations in each recursive frame to the next one using parameters. Let's see our factorial example, tail-recursive:

int helper(int num, int accumulated) {  /* C has no nested functions, so it lives at file scope */
    if (num == 0) return accumulated;
    return helper(num - 1, accumulated * num);  /* tail call */
}
int factorial(int n) { return helper(n, 1); }

Let's look at its frames in factorial(5):

[Stack f(5, 1): [Stack f(4, 5): [Stack f(3, 20): [Stack f(2, 60): [Stack f(1, 120): [Stack f(0, 120): 120]]]]]]

See the differences? In "regular" recursion, the returning calls compose the final value as the stack unwinds. In tail recursion, they only pass back the value of the base case (the last call evaluated). The argument that keeps track of the older values is called the accumulator.
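The frame trace for factorial(5) can be checked mechanically. This sketch adds hypothetical tracing arrays (`trace_num`, `trace_acc`, `MAX_TRACE` are my own names) to the tail-recursive helper and records the arguments of every call.

```c
/* Hypothetical tracing: records the (num, accumulated) pair of each call. */
#define MAX_TRACE 32
static int trace_num[MAX_TRACE];
static int trace_acc[MAX_TRACE];
static int trace_len = 0;

int helper_traced(int num, int accumulated) {
    if (trace_len < MAX_TRACE) {
        trace_num[trace_len] = num;
        trace_acc[trace_len] = accumulated;
        trace_len++;
    }
    if (num == 0) return accumulated;                  /* base case: the result is already complete */
    return helper_traced(num - 1, accumulated * num);  /* tail call: nothing left to do afterwards */
}
```

After `helper_traced(5, 1)`, the recorded accumulators are 1, 5, 20, 60, 120, 120, and the value returned by the base case is passed back unchanged through every frame.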

Recursion Templates

A regular recursive function goes as follows:

type regular(n):
    if n == base case
        return base value
    computation
    return (result of computation) combined with (regular(n towards base case))
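As a concrete instance of the regular template, here is a sketch that sums the decimal digits of a non-negative number (an example chosen only for illustration): the base case is 0 → 0, the computation is `n % 10`, and the combination is `+`.

```c
/* Regular recursion: the addition happens after the recursive call returns,
   so every frame must stay on the stack until then. */
int digit_sum(int n) {
    if (n == 0) return 0;                   /* base case */
    int last_digit = n % 10;                /* computation */
    return last_digit + digit_sum(n / 10);  /* combined with the recursive result */
}
```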

To transform it into a tail-recursive function, we:

  • Introduce a helper function that carries the accumulator
  • Run the helper function inside the main function, with the accumulator initialized to the base-case value.

Look:

type tail(n):
    type helper(n, accumulator):
        if n == base case
            return accumulator
        computation
        accumulator = computation combined with accumulator
        return helper(n towards base case, accumulator)
    return helper(n, base case value)

See the difference?
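Instantiated for a concrete case (summing the decimal digits of a non-negative number, chosen here only for illustration), the tail template might look like this sketch. The names `digit_sum_helper` and `digit_sum_tail` are hypothetical.

```c
/* The helper carries the accumulator; the recursive call is its last action. */
static int digit_sum_helper(int n, int accumulator) {
    if (n == 0) return accumulator;                /* base case: return what was accumulated */
    int last_digit = n % 10;                       /* computation */
    accumulator = last_digit + accumulator;        /* combined with the accumulator */
    return digit_sum_helper(n / 10, accumulator);  /* move towards the base case */
}

int digit_sum_tail(int n) {
    return digit_sum_helper(n, 0);  /* accumulator starts at the base-case value, 0 */
}
```

The combination step moved from "after the call" to "before the call", which is exactly what leaves nothing for the returning frames to do.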

Tail Call optimization

Since the frames of the non-base-case tail calls store no state that is still needed after the call, they are not really necessary. Some languages, compilers, and interpreters therefore replace the current stack frame with the new one instead of pushing another. With no stack frames piling up to limit the number of calls, tail calls behave just like a for-loop in these cases.

Whether this optimization is actually performed is up to your compiler; the C standard, for example, does not require it.


Here is a simple example that shows how recursive functions work:

long f (long n)
{
    if (n == 0) // have we reached the bottom of the ocean?
        return 0;

    // code placed here is executed in the descent (on the way down)

    return f(n-1) + 1; // recurrence; the "+ 1" is executed in the ascent (on the way back up)
}
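For contrast, the same count can be written so that nothing is left for the ascent: the "+ 1" is done on the way down and carried in an accumulator. `f_tail` is a hypothetical name for this sketch; call it as `f_tail(n, 0)` to get the same result as `f(n)` for n >= 0.

```c
long f_tail(long n, long acc) {
    if (n == 0) return acc;         /* bottom reached: the answer is already in acc */
    return f_tail(n - 1, acc + 1);  /* recurrence is the last action: no code runs in the ascent */
}
```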

A tail-recursive function is a recursive function in which the recurrence is done at the end of the function, so no code is executed in the ascent. This lets many compilers of high-level programming languages perform what is known as Tail Recursion Optimization; there is also a more complex optimization known as the Tail recursion modulo