For some reason, Numba fails when I pass an axis argument to np.mean. For instance, this raises an error:
import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a,-1)
b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))
TypingError: Invalid use of Function(<function mean at 0x000002949B28E1E0>) with argument(s) of type(s): (array(int32, 2d, C), Literal[int](1))
* parameterized
In definition 0:
AssertionError:
raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
In definition 1:
AssertionError:
raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function mean at 0x000002949B28E1E0>)
[2] During: typing of call at C:/Users/U374235/test.py (11)
However, this works perfectly:
import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a)
b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))
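For reference, this is what the axis argument does in plain NumPy (no Numba involved), using the same array b as in the question. The per-row result is what the failing num_prac was meant to return.

```python
import numpy as np

b = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])

overall = np.mean(b)         # no axis: mean of all ten elements
per_row = np.mean(b, -1)     # axis=-1: mean along the last axis, one value per row
per_column = np.mean(b, 0)   # axis=0: mean down each column

print(overall)     # 3.0
print(per_row)     # [3. 3.]
print(per_column)  # [1. 2. 3. 4. 5.]
```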
Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops.
NumPy can be used to perform a wide variety of mathematical operations on arrays. It adds efficient array and matrix data structures to Python, along with a large library of high-level mathematical functions that operate on them.
Numba is an open-source just-in-time compiler that developers can use to accelerate numerical functions written in standard Python, on both CPUs and GPUs.
Numba doesn't support the optional arguments of np.mean(), including axis: in nopython mode it only accepts the array itself.
You can get the same result by looping over the rows instead:
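import numpy as np
from numba import njit, prange

a = np.array([[0, 1, 2], [3, 4, 5]])
res_numpy = np.mean(a, -1)  # reference: mean along the last axis

@njit(parallel=True)
def mean_numba(a):
    # Preallocate the output so each iteration writes its own slot;
    # appending to a shared list from a prange loop is not safe.
    res = np.empty(a.shape[0])
    for i in prange(a.shape[0]):
        res[i] = a[i, :].mean()  # argument-free mean() is supported
    return res

print(np.array_equal(res_numpy, mean_numba(a)))  # True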
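If only the mean along one fixed axis is needed, another option is to divide a sum by the length of that axis. This sketch assumes a Numba version whose nopython mode accepts an integer axis for ndarray.sum() (recent releases document this, unlike for mean()); row_means is a hypothetical name, and the fallback decorator exists only so the sketch runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) if Numba is not installed
    def njit(func):
        return func


@njit
def row_means(a):
    # Assumption: sum() accepts a constant integer axis in nopython mode,
    # even though mean() does not. Dividing by the row length gives the mean.
    return a.sum(axis=1) / a.shape[1]


b = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
print(row_means(b))  # [3. 3.]
```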
Related GitHub issue: https://github.com/numba/numba/issues/1269
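The answer's loop only covers the last axis. A sketch that also handles axis=0 for 2-D input, written as plain nested loops (the kind of code Numba compiles well), could look like this. mean_2d is a hypothetical helper, not part of NumPy or Numba, it deliberately supports only 2-D arrays, and the fallback decorator exists only so the sketch runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) if Numba is not installed
    def njit(func):
        return func


@njit
def mean_2d(a, axis):
    # Hypothetical helper: mean of a 2-D array along axis 0 or 1
    # (negative values -1 and -2 are accepted, as in NumPy).
    if axis < 0:
        axis += 2
    if axis != 0 and axis != 1:
        raise ValueError("axis must be 0, 1, -1 or -2 for a 2-D array")
    n_rows, n_cols = a.shape
    if axis == 1:
        # One accumulator per row
        out = np.zeros(n_rows)
        for i in range(n_rows):
            for j in range(n_cols):
                out[i] += a[i, j]
        return out / n_cols
    # axis == 0: one accumulator per column; rows are still walked in
    # order so a C-contiguous array is read sequentially
    out = np.zeros(n_cols)
    for i in range(n_rows):
        for j in range(n_cols):
            out[j] += a[i, j]
    return out / n_rows


a = np.array([[0, 1, 2], [3, 4, 5]])
print(mean_2d(a, -1))  # [1. 4.]
print(mean_2d(a, 0))   # [1.5 2.5 3.5]
```

Keeping a single two-branch function avoids transposing the array, so both branches read memory in the same row-major order.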