
Python Numba/jit conditional and recursive (stack) use

Tags: python, jit, numba

All,

I'm using numba JIT to speed up my Python code, but the code should be functional even if numba & LLVM are not installed.

My first idea was to do this as follows:

import timeit

use_numba = True
try:
    from numba import jit, int32
except ImportError:
    use_numba = False

def run_it(parameters):
    # do something
    pass

# define a wrapper function to be compiled by numba
# NOTE: raises NameError here if numba is not installed (jit is undefined)
@jit
def run_it_with_numba(parameters):
    return run_it(parameters)

# [...]
# main program
t_start = timeit.default_timer()

# this is the code I don't like
if use_numba:
    res = run_it_with_numba(parameters)
else:
    res = run_it(parameters)

t_stop = timeit.default_timer()
print("Numba:", use_numba, " Time:", t_stop - t_start)

This does not work as I expected: the compilation seems to apply only to the run_it_with_numba() function, which does almost nothing itself, and not to the subroutines called from it.

The results only improve when I apply @jit to the function that contains the workload.

Is there a way to avoid the wrapper function and the if-clause in the main program?

Is there a way to tell Numba to also optimize the subroutines called from my entry function? run_it() itself calls other functions, and I expected @jit to handle those as well.

cu, Ale

Asked by Ale, Mar 17 '23 11:03


2 Answers

You can provide a do-nothing version of jit in case Numba is not installed:

import timeit

use_numba = True
try:
    from numba import jit, int32
except ImportError:
    use_numba = False
    # fall back to the no-op decorator and type marker from _shim.py
    from _shim import jit, int32

@jit
def run_it(parameters):
    # do something
    pass

# [...]
# main program
t_start = timeit.default_timer()

res = run_it(parameters)

t_stop = timeit.default_timer()
print("Numba:", use_numba, " Time:", t_stop - t_start)

Where _shim.py just contains:

def jit(*args, **kwargs):
    def wrapper(f):
        return f
    if len(args) > 0 and (args[0] is marker or not callable(args[0])) \
        or len(kwargs) > 0:
        # @jit(int32(int32, int32)), @jit(signature="void(int32)")
        return wrapper
    elif len(args) == 0:
        # @jit()
        return wrapper
    else:
        # @jit
        return args[0]

def marker(*args, **kwargs): return marker

int32 = marker
Answered by Sean Vieira, Apr 01 '23 18:04


I think you want to do this differently: instead of wrapping the function, optionally alias it. For example, using a dummy function that does enough work to give meaningful timings:

import numpy as np
import timeit

use_numba = True  # set to False to force the pure-Python path
try:
    import numba as nb
except ImportError:
    use_numba = False

def _run_it(a, N):
    s = 0.0
    for k in range(N):
        s += k / np.sin(a)

    return s

# alias run_it to either the compiled or the plain version, once at import time
if use_numba:
    print('Using numba')
    run_it = nb.jit()(_run_it)
else:
    print('Falling back to python')
    run_it = _run_it

if __name__ == '__main__':
    print(timeit.repeat('run_it(50.0, 100000)', setup='from __main__ import run_it', repeat=3, number=100))

Running this with the use_numba flag as True:

$ python nbtest.py
Using numba
[0.18746304512023926, 0.15185213088989258, 0.1636970043182373]

and as False:

$ python nbtest.py
Falling back to python
[9.707707166671753, 9.779848098754883, 9.770231008529663]

Or, in an IPython notebook, using the %timeit magic:

run_it_numba = nb.jit()(_run_it)

%timeit _run_it(50.0, 10000)
100 loops, best of 3: 9.51 ms per loop

%timeit run_it_numba(50.0, 10000)  
10000 loops, best of 3: 144 µs per loop

Note that when timing numba-compiled functions, the first call includes the time numba needs to compile the function. Subsequent calls with the same argument types reuse the compiled code and are much faster.

Answered by JoshAdel, Apr 01 '23 18:04