Numba failure with np.mean

For some reason, Numba fails when I add an axis argument to np.mean. For instance, this gives an error:

import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a,-1)

b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))

TypingError: Invalid use of Function(<function mean at 0x000002949B28E1E0>) with argument(s) of type(s): (array(int32, 2d, C), Literal[int](1))
 * parameterized
In definition 0:
    AssertionError: 
    raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
In definition 1:
    AssertionError: 
    raised from C:\ProgramData\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:649
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function mean at 0x000002949B28E1E0>)
[2] During: typing of call at C:/Users/U374235/test.py (11)

However, this works perfectly:

import numpy as np
from numba import jit
@jit(nopython=True)
def num_prac(a):
    return np.mean(a)

b=np.array([[1,2,3,4,5],[1,2,3,4,5]])
print(num_prac(b))
desert_ranger asked Aug 14 '19 18:08


1 Answer

In nopython mode, Numba supports np.mean() only in its basic form: it accepts the array argument alone, so optional arguments such as axis are not supported.

You can get a similar result with an explicit loop over the rows:

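import numpy as np
from numba import njit, prange

a = np.array([[0, 1, 2], [3, 4, 5]])
res_numpy = np.mean(a, -1)

@njit(parallel=True)
def mean_numba(a):
    # Preallocate the output: appending to a shared list inside prange is not thread-safe
    res = np.empty(a.shape[0])
    for i in prange(a.shape[0]):
        # mean() with no arguments is supported in nopython mode
        res[i] = a[i, :].mean()
    return res

print(np.array_equal(res_numpy, mean_numba(a)))  # True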

Related github issue: https://github.com/numba/numba/issues/1269

thibaultbl answered Oct 30 '22 18:10