Passing NumPy arrays as arguments to numba.cfunc

I have been fighting with an issue that I am having trouble wrapping my head around, and therefore don't quite know how to start solving it. My experience programming in C is very limited, and that is, I think, the reason why I cannot make progress.

I have a function that uses numpy.interp and scipy.integrate.quad to carry out a certain integral. I use quad for the integration, and according to its documentation:

A Python function or method to integrate. If func takes many arguments, it is integrated along the axis corresponding to the first argument.

If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:

double func(double x)
double func(double x, void *user_data) 
double func(int n, double *xx) 
double func(int n, double *xx, void *user_data) 

The user_data is the data contained in the scipy.LowLevelCallable. In the call forms with xx, n is the length of the xx array which contains xx[0] == x and the rest of the items are numbers contained in the args argument of quad.

In addition, certain ctypes call signatures are supported for backward compatibility, but those should not be used in new code.

I need to use scipy.LowLevelCallable objects to speed up my code, and I need to make my function conform to one of the above signatures. Moreover, since I do not want to complicate the whole thing with C libraries and compilers, I want to solve this "on the fly" with the tools available from numba, in particular numba.cfunc, which allows me to bypass the Python C API.

I have been able to solve this for an integrand that takes as an input the integration variable and an arbitrary number of scalar parameters:

    from scipy import integrate, LowLevelCallable
    from numba import njit, cfunc
    from numba.types import intc, float64, CPointer


    def jit_integrand_function(integrand_function):
        jitted_function = njit(integrand_function)

        @cfunc(float64(intc, CPointer(float64)))
        def wrapped(n, xx):
            return jitted_function(xx[0], xx[1], xx[2], xx[3])
        return LowLevelCallable(wrapped.ctypes)

    @jit_integrand_function
    def regular_function(x1, x2, x3, x4):
        return x1 + x2 + x3 + x4

    def do_integrate_wo_arrays(a, b, c, lolim=0, hilim=1):
        return integrate.quad(regular_function, lolim, hilim, (a, b, c))

    >>> print(do_integrate_wo_arrays(1,2,3,lolim=2, hilim=10))
    (96.0, 1.0658141036401503e-12)

This code works just fine. I am able to jit the integrand function and return the jitted function as a LowLevelCallable object. However, I actually need to pass two numpy arrays to my integrand, and the above construction breaks:

    import numpy as np
    from scipy import integrate, LowLevelCallable
    from numba import njit, cfunc
    from numba.types import intc, float64, CPointer


    def jit_integrand_function(integrand_function):
        jitted_function = njit(integrand_function)

        @cfunc(float64(intc, CPointer(float64)))
        def wrapped(n, xx):
            return jitted_function(xx[0], xx[1], xx[2], xx[3])
        return LowLevelCallable(wrapped.ctypes)

    @jit_integrand_function
    def function_using_arrays(x1, x2, array1, array2):
        res1 = np.interp(x1, array1[0], array1[1])
        res2 = np.interp(x2, array2[0], array2[1])

        return res1 + res2

    def do_integrate_w_arrays(a, lolim=0, hilim=1):
        foo = np.arange(20, dtype=np.float64).reshape(2, -1)
        bar = np.arange(60, dtype=np.float64).reshape(2, -1)

        return integrate.quad(function_using_arrays, lolim, hilim, (a, foo, bar))


    >>> print(do_integrate_w_arrays(3, lolim=2, hilim=10))
    Traceback (most recent call last):
      File "C:\ProgramData\Miniconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3267, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "<ipython-input-63-69c0074d4936>", line 1, in <module>
        runfile('C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py', wdir='C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec')
      File "C:\Program Files\JetBrains\PyCharm Community Edition 2018.3.4\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
        pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
      File "C:\Program Files\JetBrains\PyCharm Community Edition 2018.3.4\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py", line 29, in <module>
        @jit_integrand_function
      File "C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py", line 13, in jit_integrand_function
        @cfunc(float64(intc, CPointer(float64)))
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\decorators.py", line 260, in wrapper
        res.compile()
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
        return func(*args, **kwargs)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\ccallback.py", line 69, in compile
        cres = self._compile_uncached()
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\ccallback.py", line 82, in _compile_uncached
        cres = self._compiler.compile(sig.args, sig.return_type)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\dispatcher.py", line 81, in compile
        raise retval
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\dispatcher.py", line 91, in _compile_cached
        retval = self._compile_core(args, return_type)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\dispatcher.py", line 109, in _compile_core
        pipeline_class=self.pipeline_class)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler.py", line 528, in compile_extra
        return pipeline.compile_extra(func)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler.py", line 326, in compile_extra
        return self._compile_bytecode()
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler.py", line 385, in _compile_bytecode
        return self._compile_core()
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler.py", line 365, in _compile_core
        raise e
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler.py", line 356, in _compile_core
        pm.run(self.state)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_machinery.py", line 328, in run
        raise patched_exception
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_machinery.py", line 319, in run
        self._runPass(idx, pass_inst, state)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
        return func(*args, **kwargs)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_machinery.py", line 281, in _runPass
        mutated |= check(pss.run_pass, internal_state)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\compiler_machinery.py", line 268, in check
        mangled = func(compiler_state)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\typed_passes.py", line 94, in run_pass
        state.locals)
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\typed_passes.py", line 66, in type_inference_stage
        infer.propagate()
      File "C:\Users\mosegui\AppData\Roaming\Python\Python36\site-packages\numba\typeinfer.py", line 951, in propagate
        raise errors[0]
    numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Failed in nopython mode pipeline (step: nopython frontend)
    Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (float64, Literal[int](0))
     * parameterized
    In definition 0:
        All templates rejected with literals.
    In definition 1:
        All templates rejected without literals.
    In definition 2:
        All templates rejected with literals.
    In definition 3:
        All templates rejected without literals.
    In definition 4:
        All templates rejected with literals.
    In definition 5:
        All templates rejected without literals.
    In definition 6:
        All templates rejected with literals.
    In definition 7:
        All templates rejected without literals.
    In definition 8:
        All templates rejected with literals.
    In definition 9:
        All templates rejected without literals.
    In definition 10:
        All templates rejected with literals.
    In definition 11:
        All templates rejected without literals.
    In definition 12:
        All templates rejected with literals.
    In definition 13:
        All templates rejected without literals.
    This error is usually caused by passing an argument of a type that is unsupported by the named function.
    [1] During: typing of intrinsic-call at C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py (32)
    [2] During: typing of static-get-item at C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py (32)
    File "test_scipy_numba.py", line 32:
    def diff_moment_edge(radius, alpha, chord_df, aerodyn_df):
        <source elided>
        # # calculate blade twist for radius
        # sensor_twist = np.arctan((2 * rated_wind_speed) / (3 * rated_rotor_speed * (sensor_radius / 30.0) * radius)) * (180.0 / np.pi)
        ^
    [1] During: resolving callee type: type(CPUDispatcher(<function function_using_arrays at 0x0000020C811827B8>))
    [2] During: typing of call at C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py (15)
    [3] During: resolving callee type: type(CPUDispatcher(<function function_using_arrays at 0x0000020C811827B8>))
    [4] During: typing of call at C:/Users/mosegui/Desktop/fos4x_pkg_develop/python-packages/fos4x_tec/fos4x_tec/test_scipy_numba.py (15)
    File "test_scipy_numba.py", line 15:
    def jit_integrand_function(integrand_function):
        <source elided>
        jitted_function = njit(integrand_function)
     ^

Well, obviously this doesn't work, because in this example I have not modified the design of the decorator. But that is exactly the core of my question: I do not fully understand this situation and therefore don't know how to modify the cfunc arguments to pass an array as a parameter while still complying with the scipy.integrate.quad signature requirements. The numba documentation that introduces CPointer includes an example of how to pass an array to a numba.cfunc:

Native platform ABIs as used by C or C++ don’t have the notion of a shaped array as in Numpy. One common solution is to pass a raw data pointer and one or several size arguments (depending on dimensionality). Numba must provide a way to rebuild an array view of this data inside the callback.

    from numba import cfunc, carray
    from numba.types import float64, CPointer, void, intp

    # A callback with the C signature `void(double *, double *, size_t)`

    @cfunc(void(CPointer(float64), CPointer(float64), intp))
    def invert(in_ptr, out_ptr, n):
        in_ = carray(in_ptr, (n,))
        out = carray(out_ptr, (n,))
        for i in range(n):
            out[i] = 1 / in_[i]
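
A minimal usage sketch of the above (assuming the invert callback from the docs snippet has been compiled as shown): only raw data pointers and a length cross the C boundary, and carray rebuilds the array views on the other side.

    import ctypes
    import numpy as np

    x = np.array([1.0, 2.0, 4.0])
    out = np.empty_like(x)

    # the cfunc's .ctypes attribute is a plain ctypes function pointer
    c_double_p = ctypes.POINTER(ctypes.c_double)
    invert.ctypes(x.ctypes.data_as(c_double_p),
                  out.ctypes.data_as(c_double_p),
                  x.size)

    print(out)  # [1.   0.5  0.25]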

I somehow understand that CPointer is used for building the array on the C side: in the signature of my decorator example, CPointer(float64) gathers all the floats passed as arguments and puts them into an array. However, I am still not able to put it together and see how I can use it to pass an actual array, rather than to make an array out of the collection of float arguments I pass.

EDIT:

The answer by @max9111 worked, in the sense that I was able to pass scipy.integrate.quad a LowLevelCallable that improved the time efficiency of the calculation. This is very valuable, as it is now much clearer to me how the management of memory addresses works in C. Even though the concept of a structured array does not exist in native C, I can create a structured array in Python whose data C will see as a contiguous memory region accessible through a single memory address, and the mapping provided by the structured array identifies the different components of that memory region.
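
To make that concrete, here is a minimal standalone sketch (with made-up field names) of why a single address is enough: a structured array packs its fields into one contiguous record, and .ctypes.data is the base address that the C side receives.

    import numpy as np

    # one record holding a scalar and a fixed-shape sub-array, packed back to back
    rec_dtype = np.dtype([('a', np.float64), ('vec', np.float64, (3,))])
    rec = np.array((1.5, [0.0, 1.0, 2.0]), dtype=rec_dtype)

    print(rec.dtype.itemsize)  # total bytes of the contiguous record (8 + 3*8 = 32)
    print(rec.ctypes.data)     # the single base address handed to the C side
    print(rec['vec'])          # fields are just typed views into that same block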

Even though @max9111's solution works and solves the question I originally posted, from the Python perspective this approach introduces a certain overhead that, under certain conditions, can cost more time than is gained by calling the scipy.integrate.quad integration via a LowLevelCallable.

In my real case I am using the integration as a step of a two-dimensional optimization problem. Each step of the optimization requires integrating twice, and the integral takes nine scalar parameters and two arrays. As long as I had not been able to solve the integration via a LowLevelCallable, the only thing I could do to speed up the code was to simply njit the integrand function, and that worked decently even though the integration was still being fired via the Python API.

In my case, implementing @max9111's solution greatly improved the integration time itself (from about 0.0009 s per step to about 0.0005 s). Nevertheless, the steps of creating the structured array, C-unpacking the data, passing it to the jitted integrand and returning a LowLevelCallable added an extra 0.3 s on average per iteration, thereby worsening my situation.
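
In other words, a rough back-of-the-envelope calculation with these numbers (per optimization step, with two integrations per step as in my case):

    saved_per_integral = 0.0009 - 0.0005   # s gained per quad call by the LowLevelCallable
    overhead_per_step = 0.3                # s spent rebuilding the callable every step
    integrals_per_step = 2

    net_per_step = integrals_per_step * saved_per_integral - overhead_per_step
    print(net_per_step)                    # ~ -0.299 s, i.e. a net loss per step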

Here is some toy code showing how the LowLevelCallable approach becomes the worse option the longer the iterative process runs:

    import ctypes
    import timeit

    from tqdm import tqdm
    import numpy as np
    from scipy import integrate, LowLevelCallable
    import numba as nb
    from numba import types
    import matplotlib.pyplot as plt


    ##################################################
    # creating some sample data and parameters
    a = 3
    foo = np.arange(200, dtype=np.float64).reshape(2, -1)
    bar = np.arange(600, dtype=np.float64).reshape(2, -1)

    lim1 = 0
    lim2 = 1

    @nb.njit
    def function_using_arrays(x1, x2, array1, array2):
        res1 = np.interp(x1, array1[0], array1[1])
        res2 = np.interp(x2, array2[0], array2[1])

        return res1 + res2


    ##################################################
    # JIT INTEGRAND

    def do_integrate_w_arrays_jit(a, array1, array2, lolim=0, hilim=1):
        return integrate.quad(function_using_arrays, lolim, hilim, (a, array1, array2))

    def process_jit_integrand():
        do_integrate_w_arrays_jit(a, foo, bar, lolim=lim1, hilim=lim2)


    ##################################################
    # LOWLEV CALLABLE

    def create_jit_integrand_function(integrand_function,args,args_dtype):
        @nb.cfunc(types.float64(types.float64,types.CPointer(args_dtype)))
        def wrapped(x1,user_data_p):
            #Array of structs
            user_data = nb.carray(user_data_p, 1)

            #Extract the data
            x2=user_data[0].a
            array1=user_data[0].foo
            array2=user_data[0].bar

            return integrand_function(x1, x2, array1, array2)
        return wrapped


    def do_integrate_w_arrays_lowlev(func,args,lolim=0, hilim=1):
        integrand_func = LowLevelCallable(func.ctypes,user_data=args.ctypes.data_as(ctypes.c_void_p))
        return integrate.quad(integrand_func, lolim, hilim)


    def process_lowlev_callable():
        args_dtype = types.Record.make_c_struct([
            ('a', types.float64),
            ('foo', types.NestedArray(dtype=types.float64, shape=foo.shape)),
            ('bar', types.NestedArray(dtype=types.float64, shape=bar.shape)),])

        args=np.array((a, foo, bar), dtype=args_dtype)

        func = create_jit_integrand_function(function_using_arrays,args,args_dtype)

        do_integrate_w_arrays_lowlev(func, args, lolim=0, hilim=1)


    ##################################################

    repetitions = range(100)

    jit_integrand_delays = [timeit.timeit(stmt=process_jit_integrand, number=repetition)
                            for repetition in tqdm(repetitions)]
    lowlev_callable_delays = [timeit.timeit(stmt=process_lowlev_callable, number=repetition)
                              for repetition in tqdm(repetitions)]

    fig, ax = plt.subplots()
    ax.plot(repetitions, jit_integrand_delays, label="jit_integrand")
    ax.plot(repetitions, lowlev_callable_delays, label="lowlev_callable")
    ax.set_xlabel('number of repetitions')
    ax.set_ylabel('calculation time (s)')
    ax.set_title("Comparison calculation time")
    plt.tight_layout()
    plt.legend()
    plt.savefig(f'calculation_time_comparison_{repetitions[-1]}_reps.png')

Here the two options (only jitting the integrand vs. @max9111's solution) are compared. In this modified version of @max9111's solution I have permanently jitted the integrand function (function_using_arrays) and removed that step from create_jit_integrand_function, which reduces the "overhead" time by a neat 20%. Also for the sake of speed, I have dropped the jit_with_dummy_data function and folded its functionality into the body of process_lowlev_callable, basically to avoid an unnecessary function call. Below is the calculation time for both solutions for a series of up to 100 cycles:

[Figure: calculation time elapsed for only jitting the integrand vs. building a LowLevelCallable]

As you can see, in an iterative process the time saved in each single calculation (30+ %!) does not pay off the overhead carried by the couple of extra functions needed to build the LowLevelCallable (functions that are likewise called iteratively and run through the Python C API).

Bottom line: this solution is very good for reducing the calculation time of one single very heavy integral, but simply jitting the integrand seems to be the better choice for average integrals solved within an iterative process, since the extra functions required by the LowLevelCallable, which need to be called as often as the integration itself, take their toll.

Anyway, thank you very much. Even though this solution will not work for me, I learned valuable things and I consider my question solved.

EDIT 2:

I misunderstood parts of @max9111's solution and the role played by the function create_jit_integrand_function, and I was wrongly compiling the LowLevelCallable in each step of my optimization (which I do not need to do because, even though the parameters and arrays passed to the integral change each iteration, their shapes, and therefore the shape of the C struct, remain constant).

Here is the refactored version of the code from the EDIT above:


    import ctypes
    import timeit

    from tqdm import tqdm
    import numpy as np
    from scipy import integrate, LowLevelCallable
    import numba as nb
    from numba import types
    import matplotlib.pyplot as plt


    ##################################################
    # creating some sample data and parameters
    a = 3
    foo = np.arange(200, dtype=np.float64).reshape(2, -1)
    bar = np.arange(600, dtype=np.float64).reshape(2, -1)

    lim1 = 0
    lim2 = 1


    def function_using_arrays(x1, x2, array1, array2):
        res1 = np.interp(x1, array1[0], array1[1])
        res2 = np.interp(x2, array2[0], array2[1])

        return res1 + res2


    ##################################################
    # JIT INTEGRAND

    def do_integrate_w_arrays_jit(a, array1, array2, lolim=0, hilim=1):
        return integrate.quad(nb.njit(function_using_arrays), lolim, hilim, (a, array1, array2))

    def process_jit_integrand():
        do_integrate_w_arrays_jit(a, foo, bar, lolim=lim1, hilim=lim2)


    ##################################################
    # LOWLEV CALLABLE

    def create_jit_integrand_function(integrand_function, args_dtype):
        jitted_function = nb.njit(integrand_function)

        @nb.cfunc(types.float64(types.float64,types.CPointer(args_dtype)))
        def wrapped(x1,user_data_p):
            #Array of structs
            user_data = nb.carray(user_data_p, 1)

            #Extract the data
            x2=user_data[0].a
            array1=user_data[0].foo
            array2=user_data[0].bar

            return jitted_function(x1, x2, array1, array2)
        return wrapped


    def do_integrate_w_arrays_lowlev(func,args,lolim=0, hilim=1):
        integrand_func=LowLevelCallable(func.ctypes,user_data=args.ctypes.data_as(ctypes.c_void_p))
        return integrate.quad(integrand_func, lolim, hilim)


    def process_lowlev_callable():
        do_integrate_w_arrays_lowlev(func, np.array((a, foo, bar), dtype=args_dtype), lolim=0, hilim=1)


    ##################################################

    repetitions = range(100)

    jit_integrand_delays = [timeit.timeit(stmt=process_jit_integrand, number=repetition) for repetition in tqdm(repetitions)]


    args_dtype = types.Record.make_c_struct([
            ('a', types.float64),
            ('foo', types.NestedArray(dtype=types.float64, shape=foo.shape)),
            ('bar', types.NestedArray(dtype=types.float64, shape=bar.shape)),])
    func = create_jit_integrand_function(function_using_arrays, args_dtype)


    lowlev_callable_delays = [timeit.timeit(stmt=process_lowlev_callable, number=repetition) for repetition in tqdm(repetitions)]

    fig, ax = plt.subplots()
    ax.plot(repetitions, jit_integrand_delays, label="jit_integrand")
    ax.plot(repetitions, lowlev_callable_delays, label="lowlev_callable")
    ax.set_xlabel('number of repetitions')
    ax.set_ylabel('calculation time (s)')
    ax.set_title("Comparison calculation time")
    plt.tight_layout()
    plt.legend()
    plt.savefig(f'calculation_time_comparison_{repetitions[-1]}_reps.png')
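
Since the struct layout is fixed, the same compiled callback can also be reused inside the actual optimization loop by only updating the values in the record. A minimal sketch (the loop below is a made-up stand-in for the real optimizer, reusing the names defined above):

    # reusable one-element record buffer; shapes match what the callback was compiled for
    args_buffer = np.empty(1, dtype=args_dtype)

    for step in range(5):                            # stand-in for the real optimizer loop
        args_buffer[0] = (a + 0.1 * step, foo, bar)  # only the values change each step
        res, err = do_integrate_w_arrays_lowlev(func, args_buffer, lolim=0, hilim=1)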

In this configuration, the building of the LowLevelCallable (which indeed costs a bit of time) only has to be carried out once, and the overall process is orders of magnitude faster:

[Figure: calculation time comparison, jit_integrand vs. lowlev_callable, over 100 repetitions]

and the close-up for the lowlev_callable:

[Figure: close-up of the lowlev_callable curve]


1 Answer (by max9111)

You can use the user_data input to pass arrays

As I understand the documentation of scipy.integrate.quad, it isn't possible to pass arrays via the args parameter when using a scipy.LowLevelCallable, but you can pass arbitrary user_data.

In the following example I used this signature:

double func(double x, void *user_data)

Edit: Arbitrary shapes of arrays without recompilation

Using this answer it is also possible to compile the function once for arbitrary array shapes (only the number of dimensions is fixed).

import numpy as np
import numba as nb
from numba import types
from scipy import integrate, LowLevelCallable
import ctypes

#Void Pointer from Int64
@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba import types 
    from numba.core import cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

def create_jit_integrand_function(integrand_function,args_dtype):
    jitted_function = nb.njit(integrand_function)

    #double func(double x, void *user_data)
    @nb.cfunc(types.float64(types.float64,types.CPointer(args_dtype)))
    def wrapped(x1,user_data_p):
        #Array of structs
        user_data = nb.carray(user_data_p, 1)

        #Extract the data
        x2=user_data[0].a
        array1=nb.carray(address_as_void_pointer(user_data[0].foo_p),(user_data[0].foo_s1,user_data[0].foo_s2),dtype=np.float64)
        array2=nb.carray(address_as_void_pointer(user_data[0].bar_p),(user_data[0].bar_s1,user_data[0].bar_s2),dtype=np.float64)

        return jitted_function(x1, x2, array1, array2)
    return wrapped

def function_using_arrays(x1, x2, array1, array2):
    res1 = np.interp(x1, array1[0], array1[1])
    res2 = np.interp(x2, array2[0], array2[1])

    return res1 + res2

def do_integrate_w_arrays(func,args,lolim=0, hilim=1):
    integrand_func=LowLevelCallable(func.ctypes,user_data=args.ctypes.data_as(ctypes.c_void_p))
    return integrate.quad(integrand_func, lolim, hilim)

#Define the datatype of the struct array
#Pointers are not allowed, therefore we use int64
args_dtype = types.Record.make_c_struct([
    ('a', types.float64),
    ('foo_p',  types.int64),
    ('foo_s1', types.int64),
    ('foo_s2', types.int64),
    ('bar_p',  types.int64),
    ('bar_s1', types.int64),
    ('bar_s2', types.int64),])

#creating some sample data
#The arrays must be C-contiguous
#To ensure that, you can use np.ascontiguousarray

a=3
foo = np.ascontiguousarray(np.arange(200, dtype=np.float64).reshape(2, -1))
bar = np.ascontiguousarray(np.arange(600, dtype=np.float64).reshape(2, -1))


args=np.array((a,foo.ctypes.data,foo.shape[0],foo.shape[1],
                 bar.ctypes.data,bar.shape[0],bar.shape[1]),dtype=args_dtype)

#compile the integration function (only the number of dimensions is fixed here)
#Only a structured array laid out like args is allowed as user_data
func=create_jit_integrand_function(function_using_arrays,args_dtype)


print(do_integrate_w_arrays(func,args, lolim=0, hilim=1))
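
Because the struct only carries pointers and shape integers, the same compiled callback can be reused for differently shaped (C-contiguous) arrays without recompiling. A follow-up sketch with new, made-up sample arrays:

#reuse the compiled function with differently shaped arrays (no recompilation needed)
foo2 = np.ascontiguousarray(np.arange(74, dtype=np.float64).reshape(2, -1))
bar2 = np.ascontiguousarray(np.arange(22, dtype=np.float64).reshape(2, -1))

args2 = np.array((a, foo2.ctypes.data, foo2.shape[0], foo2.shape[1],
                  bar2.ctypes.data, bar2.shape[0], bar2.shape[1]), dtype=args_dtype)

print(do_integrate_w_arrays(func, args2, lolim=0, hilim=1))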

Old version

As I am passing a structured array, a recompilation is needed if the array shapes or datatypes change. This isn't a limitation of the API itself. There must be an easier way to do this (maybe using tuples?).

Implementation

import numpy as np
import numba as nb
from numba import types
from scipy import integrate, LowLevelCallable
import ctypes

def create_jit_integrand_function(integrand_function,args,args_dtype):
    jitted_function = nb.njit(integrand_function)

    @nb.cfunc(types.float64(types.float64,types.CPointer(args_dtype)))
    def wrapped(x1,user_data_p):
        #Array of structs
        user_data = nb.carray(user_data_p, 1)

        #Extract the data
        x2=user_data[0].a
        array1=user_data[0].foo
        array2=user_data[0].bar

        return jitted_function(x1, x2, array1, array2)
    return wrapped

def function_using_arrays(x1, x2, array1, array2):
    res1 = np.interp(x1, array1[0], array1[1])
    res2 = np.interp(x2, array2[0], array2[1])

    return res1 + res2

def jit_with_dummy_data(args,args_dtype):
    func=create_jit_integrand_function(function_using_arrays,args,args_dtype)
    return func

def do_integrate_w_arrays(func,args,lolim=0, hilim=1):
    integrand_func=LowLevelCallable(func.ctypes,user_data=args.ctypes.data_as(ctypes.c_void_p))
    return integrate.quad(integrand_func, lolim, hilim)

Using the Implementation

#creating some sample data
a=3
foo = np.arange(200, dtype=np.float64).reshape(2, -1)
bar = np.arange(600, dtype=np.float64).reshape(2, -1)

args_dtype = types.Record.make_c_struct([
    ('a', types.float64),
    ('foo', types.NestedArray(dtype=types.float64, shape=foo.shape)),
    ('bar', types.NestedArray(dtype=types.float64, shape=bar.shape)),])

args=np.array((a,foo,bar),dtype=args_dtype)

#compile the integration function (array-shapes are fixed)
#There is only a structured array like args allowed
func=jit_with_dummy_data(args,args_dtype)


print(do_integrate_w_arrays(func,args, lolim=0, hilim=1))
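
For comparison, a sketch of what the fixed-shape limitation means in practice: if, say, foo gets a new shape, both the struct dtype and the callback have to be rebuilt (foo_new, args_dtype_new etc. are made-up names):

#changing an array shape forces a new struct dtype and a fresh compilation
foo_new = np.arange(300, dtype=np.float64).reshape(2, -1)

args_dtype_new = types.Record.make_c_struct([
    ('a', types.float64),
    ('foo', types.NestedArray(dtype=types.float64, shape=foo_new.shape)),
    ('bar', types.NestedArray(dtype=types.float64, shape=bar.shape)),])

args_new = np.array((a, foo_new, bar), dtype=args_dtype_new)
func_new = jit_with_dummy_data(args_new, args_dtype_new)   #compiles again for the new shapes

print(do_integrate_w_arrays(func_new, args_new, lolim=0, hilim=1))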