Using a function object as an argument for numba njit function

Tags: python, numba

I want to make a general function, which takes a function object as an argument.

One of the simplest cases:

import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
    return f(a)

test(np.arange(10), np.mean)

gives an error, although test(np.arange(10)) works as expected.

The error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
[1] During: typing of argument at <ipython-input-54-52cead0f097d> (5)

File "<ipython-input-54-52cead0f097d>", line 5:
def test(a, f=np.median):
    return f(a)
    ^

This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'function'>

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

Is this not allowed or am I missing something?

asked Mar 03 '23 at 23:03 by ysBach


1 Answer

Using functions as arguments is tricky with numba and adds dispatch overhead. This is covered in the Frequently Asked Questions of the numba documentation:

1.18.1.1. Can I pass a function as an argument to a jitted function?

As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled:

@jit(nopython=True)
def f(g, x):
    return g(x) + g(-x)
result = f(jitted_g_function, 1)

However, dispatching with arguments that are functions has extra overhead. If this matters for your application, you can also use a factory function to capture the function argument in a closure:

def make_f(g):
    # Note: a new f() is created each time make_f() is called!
    @jit(nopython=True)
    def f(x):
        return g(x) + g(-x)
    return f
f = make_f(jitted_g_function)
result = f(1)

Improving the dispatch performance of functions in Numba is an ongoing task.

This means you have the option to use a function factory:

import numpy as np
import numba as nb

def test(a, func=np.median):
    @nb.njit
    def _test(a):
        return func(a)
    return _test(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), np.min)
0
>>> test(np.arange(10), np.mean)
4.5

Or to wrap the function argument in a jitted function before passing it in as an argument:

import numpy as np
import numba as nb

@nb.njit()
def test(a, f=np.median):
    return f(a)

@nb.njit
def wrapped_mean(a):
    return np.mean(a)

@nb.njit
def wrapped_median(a):
    return np.median(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), wrapped_mean)
4.5
>>> test(np.arange(10), wrapped_median)
4.5

Both options have quite a bit of boilerplate and aren't as straightforward as one might hope.

The function-factory approach also repeatedly creates and compiles functions, so if you often call it with the same function as the argument, you could use a dictionary to cache the compiled functions:

import numpy as np
import numba as nb

_precompiled_funcs = {}

def test(a, func=np.median):
    if func not in _precompiled_funcs:
        @nb.njit
        def _test(arr):
            return func(arr)
        result = _test(a)
        _precompiled_funcs[func] = _test
        return result
    return _precompiled_funcs[func](a)

The other approach (using the wrapped and jitted functions) also has some overhead; however, it's not really noticeable as long as the arrays you pass in have a significant number of elements (>1000).

If the function you've shown is really the function you want to use, I wouldn't use numba on it at all. For simple tasks like this that don't exercise numba's strengths (indexing and iterating over arrays, or heavy number crunching), plain Python + NumPy should be as fast or faster, and much easier to debug and understand:

import numpy as np

def test(a, f=np.median):
    return f(a)
answered Mar 05 '23 at 15:03 by MSeifert