I want to write a general function that takes a function object as an argument.
One of the simplest cases:
import numpy as np
import numba as nb

@nb.njit()
def test(a, f=np.median):
    return f(a)

test(np.arange(10), np.mean)

gives an error, although test(np.arange(10)) works as expected.
The error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
[1] During: typing of argument at <ipython-input-54-52cead0f097d> (5)

File "<ipython-input-54-52cead0f097d>", line 5:
def test(a, f=np.median):
    return f(a)
    ^

This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'function'>

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
Is this not allowed or am I missing something?
Using functions as arguments is tricky with Numba and comes with extra dispatch overhead. This is covered in the Frequently Asked Questions, under "1.18.1.1. Can I pass a function as an argument to a jitted function?":
As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled:

@jit(nopython=True)
def f(g, x):
    return g(x) + g(-x)

result = f(jitted_g_function, 1)

However, dispatching with arguments that are functions has extra overhead. If this matters for your application, you can also use a factory function to capture the function argument in a closure:

def make_f(g):
    # Note: a new f() is created each time make_f() is called!
    @jit(nopython=True)
    def f(x):
        return g(x) + g(-x)
    return f

f = make_f(jitted_g_function)
result = f(1)

Improving the dispatch performance of functions in Numba is an ongoing task.
This means you have the option to use a function factory:
import numpy as np
import numba as nb

def test(a, func=np.median):
    @nb.njit
    def _test(a):
        return func(a)
    return _test(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), np.min)
0
>>> test(np.arange(10), np.mean)
4.5
Or you can wrap the function argument in a jitted function before passing it in:
import numpy as np
import numba as nb

@nb.njit()
def test(a, f=np.median):
    return f(a)

@nb.njit
def wrapped_mean(a):
    return np.mean(a)

@nb.njit
def wrapped_median(a):
    return np.median(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), wrapped_mean)
4.5
>>> test(np.arange(10), wrapped_median)
4.5
Both options involve quite a bit of boilerplate and aren't as straightforward as one might hope.
The function-factory approach also creates and compiles a new function on every call, so if you often call it with the same function as argument, you could use a dictionary to store the already compiled functions:
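import numpy as np
import numba as nb

# Cache: maps each func to its already compiled jitted wrapper.
_precompiled_funcs = {}

def test(a, func=np.median):
    if func not in _precompiled_funcs:
        # First call with this func: build and compile a wrapper, then cache it.
        @nb.njit
        def _test(arr):
            return func(arr)
        result = _test(a)
        _precompiled_funcs[func] = _test
        return result
    return _precompiled_funcs[func](a)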
The other approach (passing in wrapped, jitted functions) also has some overhead; however, it's not really noticeable as long as the arrays you pass in have a significant number of elements (>1000).
If the function you've shown is really the function you want to use, I wouldn't use Numba on it at all. For simple tasks like this, which don't exercise Numba's strengths (indexing and iterating over arrays, or heavy number crunching), plain Python + NumPy should be faster (or as fast) and much easier to debug and understand:
import numpy as np

def test(a, f=np.median):
    return f(a)