Why is tail recursion optimization faster than normal recursion in Python?

While I understand that tail recursion optimization is non-Pythonic, I came up with a quick hack for a question on here that was deleted just as I was ready to post.

With the default stack limit of 1000 frames, deeply recursive algorithms are not usable in Python. But recursion is sometimes great for thinking through a solution. Since functions are first class in Python, I played with having each call return the next function to call along with the next value, then driving the process in a loop of single calls until done. I'm sure this isn't new.

What I found interesting is that I expected the extra overhead of passing the function back and forth to make this slower than normal recursion. In my crude testing I found it to take 30-50% of the time of normal recursion (with the added bonus of allowing LONG recursions).

Here is the code I'm running:

from contextlib import contextmanager
import time

# Timing code most likely from StackOverflow.
@contextmanager
def time_block(label):
    start = time.perf_counter()  # time.clock() was removed in Python 3.8
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{} : {}'.format(label, end - start))


# Purely Recursive Function
def find_zero(num):
    if num == 0:
        return num
    return find_zero(num - 1)


# Function that returns a tuple of (next method, next value)
def find_zero_tail(num):
    if num == 0:
        return None, num
    return find_zero_tail, num - 1


# Iterative recurser
def tail_optimize(method, val):
    while method:
        method, val = method(val)
    return val


with time_block('Pure recursion: 998'):
    find_zero(998)

with time_block('Tail Optimize Hack: 998'):
    tail_optimize(find_zero_tail, 998)

with time_block('Tail Optimize Hack: 1000000'):
    tail_optimize(find_zero_tail, 1000000)

# One Run Result:
# Pure recursion: 998 : 0.000372791020758
# Tail Optimize Hack: 998 : 0.000163852100569
# Tail Optimize Hack: 1000000 : 1.51006975627

Why is the second style faster?

My guess is the overhead of creating entries on the stack, but I'm not sure how to find out.
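One way to poke at it (just a sketch, using the standard cProfile module on the functions above; it shows call counts and per-call overhead rather than stack allocation directly):

import cProfile

# find_zero shows 999 recursive calls on one growing stack;
# find_zero_tail shows 999 separate short-lived calls driven by tail_optimize.
cProfile.run('find_zero(998)')
cProfile.run('tail_optimize(find_zero_tail, 998)')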

Edit:

While playing with call counts, I made a loop to try both at various num values. The recursive version was much closer to parity when I looped and called it multiple times.

So I added this before the timing; it is just find_zero under a new name:

def unrelated_recursion(num):
    if num == 0:
        return num
    return unrelated_recursion(num - 1)

unrelated_recursion(998)

Now the tail-optimized call takes 85% of the time of the full recursion.

So my theory is that the 15% penalty is the overhead of the larger stack, versus a single stack frame.

The reason I saw such a huge disparity in execution time when running each only once was the penalty for allocating the stack memory and structures. Once they are allocated, the cost of using them drops drastically.

Because my algorithm is dead simple, the memory structure allocation is a large portion of the execution time.

When I cut my stack-priming call to unrelated_recursion(499), the find_zero(998) execution time lands about halfway between the fully primed and unprimed cases. This is consistent with the theory.
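A quick way to check the priming theory directly (a sketch, assuming the functions above are already defined; whether the warm run is actually faster depends on how the interpreter reuses frames):

import time

def timed(label, fn, arg):
    # Single-shot timing with a high-resolution timer.
    start = time.perf_counter()
    fn(arg)
    print('{} : {}'.format(label, time.perf_counter() - start))

timed('find_zero cold', find_zero, 998)  # first deep call pays any allocation cost
timed('find_zero warm', find_zero, 998)  # second call can reuse what was allocated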

asked May 12 '16 by Joe


People also ask

Is tail recursion faster than normal recursion?

As a rule of thumb, tail-recursive functions are faster if they don't need to reverse the result before returning it, because reversing requires another pass over the whole list. Tail-recursive functions are usually faster at reducing lists.

Why is a tail-recursive function more efficient than a normal recursive function?

Tail recursion is better than non-tail recursion: since no work is left after the recursive call, it is easier for the compiler to optimize the code. When a function is called, its return address is stored on the stack; with tail recursion, storing those addresses is not needed.

Why is tail recursion more efficient?

In tail recursion, there is no other operation to perform after the recursive call; the function can directly return the result of that call. In other words, the recursive call is the last thing the function does, which is what makes it more efficient than non-tail recursion.

Is tail recursion optimized in Python?

There is no built-in tail recursion optimization in Python.
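Since there is no built-in optimization, the usual workaround for depth (not speed) is raising the interpreter's limit; a minimal sketch:

import sys

sys.setrecursionlimit(20000)  # default is 1000; setting this too high can overflow the C stack

def count_down(n):
    return n if n == 0 else count_down(n - 1)

count_down(10000)  # would raise RecursionError under the default limit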




1 Answer

As a comment helpfully reminded me, I was not really answering the question, so here is my take:

In your optimization, you're allocating, unpacking, and deallocating tuples, so I tried it without them:

# Version without tuples: return just the next value (None when done)
def find_zero_tail(num):
    if num == 0:
        return None
    return num - 1


# Iterative recurser
def tail_optimize(method, val):
    while val:  # the loop ends when the value reaches 0 (falsy)
        val = method(val)
    return val

For 1000 tries, each starting with value = 998 (a possible harness is sketched below):

  • this version took 0.16s
  • your "optimized" version took 0.22s
  • the "unoptimized" one took 0.29s

(Note that for me, your optimized version is faster than the un-optimized one... but we're not running exactly the same test.)
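The harness isn't shown in the answer; a minimal timeit sketch for this kind of 1000-try comparison (names taken from the code above, using whichever variant of find_zero_tail is currently defined):

import timeit

setup = 'from __main__ import find_zero, find_zero_tail, tail_optimize'
# 1000 runs each, starting from 998, mirroring the comparison above.
print('trampoline    :', timeit.timeit('tail_optimize(find_zero_tail, 998)', setup=setup, number=1000))
print('pure recursion:', timeit.timeit('find_zero(998)', setup=setup, number=1000))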

But I don't think those stats are that useful: the cost sits more on Python's side (method calls, tuple allocations, ...) than in your code doing real work. In a real application you won't end up measuring the cost of 1000 tuples, but the cost of your actual implementation.

But simply don't do this: it is hard to read for almost no gain. You're writing for the reader, not for the machine:

# Function that returns a tuple of (next method, next value)
def find_zero_tail(num):
    if num == 0:
        return None, num
    return find_zero_tail, num - 1


# Iterative recurser
def tail_optimize(method, val):
    while method:
        method, val = method(val)
    return val

I won't try to implement a more readable version of it because I'll probably end up with:

def find_zero(val):
    return 0

But I think that in real cases there are some nice ways to deal with recursion limits (on both the memory-size and depth sides):

To help with memory (not depth), an lru_cache from functools typically helps a lot:

>>> from functools import lru_cache
>>> @lru_cache()
... def fib(x):
...     return fib(x - 1) + fib(x - 2) if x > 2 else 1
... 
>>> fib(100)
354224848179261915075

And for stack size, you may use a list or a deque, depending on your context and usage, instead of the language's call stack. Depending on the exact implementation (when you're in fact storing simple sub-computations on your stack to reuse them), this is called dynamic programming:

>>> def fib(x):
...     stack = [1, 1]
...     while len(stack) < x:
...         stack.append(stack[-1] + stack[-2])
...     return stack[-1]
... 
>>> fib(100)
354224848179261915075

But, and here comes the nice part of using your own structure instead of the call stack: you don't always need to keep the whole stack to continue the computation:

>>> def fib(x):
...     stack = (1, 1)
...     for _ in range(x - 2):
...         stack = stack[1], stack[0] + stack[1]
...     return stack[1]
... 
>>> fib(100)
354224848179261915075
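Where the problem genuinely needs a stack (not just the last two values), an explicit deque or list stands in for the call stack; a sketch, using a made-up nested-list sum as the workload:

>>> from collections import deque
>>> def nested_sum(items):
...     total, stack = 0, deque([items])
...     while stack:
...         current = stack.pop()
...         for item in current:
...             if isinstance(item, list):
...                 stack.append(item)  # defer nested lists instead of recursing
...             else:
...                 total += item
...     return total
... 
>>> nested_sum([1, [2, [3, 4]], 5])
15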

But to conclude with a nice touch of "know the problem before trying to implement it" (unreadable, hard to debug, hard to visually prove; it's bad code, but it's fun):

>>> def fib(n):
...     return (4 << n*(3+n)) // ((4 << 2*n) - (2 << n) - 1) & ((2 << n) - 1)
... 
>>> 
>>> fib(99)
354224848179261915075

If you ask me, the best implementation is the most readable one: for the Fibonacci example, probably the one with an LRU cache, but with the ... if ... else ... replaced by a more readable if statement (see the sketch below); for another example a deque may be more readable; and for other examples, dynamic programming may be better...
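For what it's worth, that more readable lru_cache variant might look like this (same results as the earlier example, just with an explicit if):

>>> from functools import lru_cache
>>> @lru_cache()
... def fib(x):
...     if x <= 2:
...         return 1
...     return fib(x - 1) + fib(x - 2)
... 
>>> fib(100)
354224848179261915075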

"You're writing for the human reading your code, not for the machine".

answered Sep 24 '22 by Julien Palard