
How does a decorator work when its argument is a recursive function?

import time
def clock(func):
   def clocked(*args):
       t0 = time.perf_counter()
       result = func(*args) 
       elapsed = time.perf_counter() - t0
       name = func.__name__
       arg_str = ', '.join(repr(arg) for arg in args)
       print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
       return result
   return clocked

That is the decorator. I apply it to a recursive function:

@clock
def factorial(n):
    return 1 if n < 2 else n*factorial(n-1)

Part of the output is:

[0.00000191s] factorial(1) -> 1
[0.00004911s] factorial(2) -> 2
[0.00008488s] factorial(3) -> 6
[0.00013208s] factorial(4) -> 24
[0.00019193s] factorial(5) -> 120
[0.00026107s] factorial(6) -> 720
6! = 720

How does this decorator work when its argument is a recursive function? Why does the decorator appear to be executed many times?

Dean Wang asked Mar 02 '26 20:03



1 Answer

In your example, the clock decorator is executed once, when it replaces the original version of factorial with the clocked version. The original factorial is recursive and therefore the decorated version is recursive too. And so you get the timing data printed for each recursive call - the decorated factorial calls itself, not the original version, because the name factorial now refers to the decorated version.
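
To make the "decorated once, called many times" distinction concrete, here is a minimal sketch. The trace decorator and the two lists are hypothetical names invented for this illustration; they are not part of your code. It records when the decorator body runs and when the wrapper runs.

```python
decorated = []   # names of functions passed to the decorator
calls = []       # arguments seen by the wrapper

def trace(func):
    # This body runs once, when the def statement below is executed.
    decorated.append(func.__name__)
    def wrapper(n):
        # This runs on every call, including the recursive ones,
        # because the global name factorial is bound to wrapper.
        calls.append(n)
        return func(n)
    return wrapper

@trace   # same effect as writing: factorial = trace(factorial)
def factorial(n):
    return 1 if n < 2 else n * factorial(n - 1)

result = factorial(4)
print(decorated)  # ['factorial']
print(calls)      # [4, 3, 2, 1]
print(result)     # 24
```

The decorator itself ran a single time, but the wrapper it returned ran once per level of the recursion.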


It's a good idea to use functools.wraps in decorators. This copies various attributes of the original function to the decorated version.

For example, without wraps:

import time

def clock(func):
    def clocked(*args):
        ''' Clocking decoration wrapper '''
        t0 = time.perf_counter()
        result = func(*args) 
        elapsed = time.perf_counter() - t0
        name = func.__name__
        arg_str = ', '.join(repr(arg) for arg in args)
        print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
        return result
    return clocked

@clock
def factorial(n):
    ''' Recursive factorial '''
    return 1 if n < 2 else n * factorial(n-1)

print(factorial.__name__, factorial.__doc__)

output

clocked  Clocking decoration wrapper 

With wraps:

import time
from functools import wraps

def clock(func):
    @wraps(func)
    def clocked(*args):
        ''' Clocking decoration wrapper '''
        t0 = time.perf_counter()
        result = func(*args) 
        elapsed = time.perf_counter() - t0
        name = func.__name__
        arg_str = ', '.join(repr(arg) for arg in args)
        print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
        return result
    return clocked

@clock
def factorial(n):
    ''' Recursive factorial '''
    return 1 if n < 2 else n * factorial(n-1)

print(factorial.__name__, factorial.__doc__)

output

factorial  Recursive factorial 

which is what we'd get if we did print(factorial.__name__, factorial.__doc__) on the undecorated version.
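
Besides copying the name and docstring, wraps also stores a reference to the original function in the __wrapped__ attribute of the wrapper. A small sketch, using a hypothetical do-nothing decorator called trace in place of clock:

```python
from functools import wraps

def trace(func):
    @wraps(func)
    def wrapper(*args):
        return func(*args)
    return wrapper

@trace
def factorial(n):
    ''' Recursive factorial '''
    return 1 if n < 2 else n * factorial(n - 1)

# wraps copied the metadata from the original function...
print(factorial.__name__)   # factorial
print(factorial.__doc__)    #  Recursive factorial

# ...and kept a reference to the undecorated function.
original = factorial.__wrapped__
print(original is factorial)  # False
print(original(5))            # 120
```

Note that calling original(5) only bypasses the wrapper for the outermost call. The recursive calls inside it still look up the name factorial, so they go through the decorated version.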


If you don't want the clock-decorated recursive function to print the timing info for all of the recursive calls, it gets a bit tricky.

The simplest way is to not use the decorator syntax and just call clock as a normal function so we get a new name for the clocked version of the function:

def factorial(n):
    return 1 if n < 2 else n * factorial(n-1)

clocked_factorial = clock(factorial)

for n in range(7):
    print('%d! = %d' % (n, clocked_factorial(n)))

output

[0.00000602s] factorial(0) -> 1
0! = 1
[0.00000302s] factorial(1) -> 1
1! = 1
[0.00000581s] factorial(2) -> 2
2! = 2
[0.00000539s] factorial(3) -> 6
3! = 6
[0.00000651s] factorial(4) -> 24
4! = 24
[0.00000742s] factorial(5) -> 120
5! = 120
[0.00000834s] factorial(6) -> 720
6! = 720
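
The reason this works is that the body of factorial looks up the global name factorial each time it recurses, and that name still refers to the undecorated function. The sketch below illustrates this with a hypothetical trace decorator that records its arguments in a list instead of printing timings.

```python
calls = []

def trace(func):
    def wrapper(n):
        calls.append(n)
        return func(n)
    return wrapper

def factorial(n):
    return 1 if n < 2 else n * factorial(n - 1)

# New name for the wrapper: the name factorial is left alone,
# so the recursive calls never pass through the wrapper.
traced_factorial = trace(factorial)
traced_factorial(4)
first_calls = list(calls)
print(first_calls)   # [4]

# Rebinding the name factorial (which is what @trace does) sends
# every recursive call through the wrapper.
calls.clear()
factorial = trace(factorial)
factorial(4)
print(calls)         # [4, 3, 2, 1]
```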

Another way is to wrap the recursive function in a non-recursive function and apply the decorator to the new function.

def factorial(n):
    return 1 if n < 2 else n * factorial(n-1)

@clock
def nr_factorial(n):
    return factorial(n)

for n in range(3, 7):
    print('%d! = %d' % (n, nr_factorial(n)))

output

[0.00001018s] nr_factorial(3) -> 6
3! = 6
[0.00000799s] nr_factorial(4) -> 24
4! = 24
[0.00000801s] nr_factorial(5) -> 120
5! = 120
[0.00000916s] nr_factorial(6) -> 720
6! = 720

Another way is to modify the decorator so that it keeps track of whether it's executing the top level of the recursion or one of the inner levels, and only prints the timing info for the top level. This version uses the nonlocal statement, so it only works in Python 3, not Python 2.

def rclock(func):
    top = True
    @wraps(func)
    def clocked(*args):
        nonlocal top
        if not top:
            # Inner level of the recursion: call through without timing.
            return func(*args)
        top = False
        try:
            t0 = time.perf_counter()
            result = func(*args)
            elapsed = time.perf_counter() - t0
        finally:
            # Always restore the flag, even if func returns without
            # recursing (e.g. factorial(1)) or raises an exception.
            top = True
        name = func.__name__
        arg_str = ', '.join(repr(arg) for arg in args)
        print('[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
        return result
    return clocked

@rclock
def factorial(n):
    return 1 if n < 2 else n * factorial(n-1)

for n in range(3, 7):
    print('%d! = %d' % (n, factorial(n))) 

output

[0.00001253s] factorial(3) -> 6
3! = 6
[0.00001205s] factorial(4) -> 24
4! = 24
[0.00001227s] factorial(5) -> 120
5! = 120
[0.00001422s] factorial(6) -> 720
6! = 720

The rclock function can be used on non-recursive functions, but it's a little more efficient to just use the original version of clock.


Another handy function in functools that you should know about if you're using recursive functions is lru_cache. This keeps a cache of recently computed results so they don't need to be re-computed. This can enormously speed up recursive functions that get called repeatedly with the same arguments. Please see the docs for details.

We can use lru_cache in conjunction with clock or rclock.

from functools import lru_cache

@lru_cache(maxsize=None)
@clock
def factorial(n):
    return 1 if n < 2 else n * factorial(n-1)

for n in range(3, 7):
    print('%d! = %d' % (n, factorial(n)))

output

[0.00000306s] factorial(1) -> 1
[0.00017850s] factorial(2) -> 2
[0.00022049s] factorial(3) -> 6
3! = 6
[0.00000542s] factorial(4) -> 24
4! = 24
[0.00000417s] factorial(5) -> 120
5! = 120
[0.00000409s] factorial(6) -> 720
6! = 720

As you can see, even though we used the plain clock decorator, only a single line of timing info gets printed for the factorials of 4, 5, and 6, because the smaller factorials are read from the cache instead of being re-computed.
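
If you want to check the caching behaviour without the timing noise, the wrapper returned by lru_cache has a cache_info() method that reports hits and misses. Here is a rough sketch; the calls list is just a hypothetical helper that records which arguments actually reached the function body.

```python
from functools import lru_cache

calls = []   # arguments for which the function body really ran

@lru_cache(maxsize=None)
def factorial(n):
    calls.append(n)
    return 1 if n < 2 else n * factorial(n - 1)

for n in range(3, 7):
    factorial(n)

# factorial(3) has to compute 3, 2 and 1; after that each new value
# needs only one fresh computation, the rest comes from the cache.
print(calls)   # [3, 2, 1, 4, 5, 6]

info = factorial.cache_info()
print(info.hits, info.misses, info.currsize)   # 3 6 6
```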

PM 2Ring answered Mar 05 '26 09:03



