NumPy zeros in Numba function is not working when multi-dimensional shape is a list

I tried giving numba a go, as I was told it works very well for numerical/scientific computing applications. However, it seems that I've already run into a problem in the following scenario:

I have a function that computes a 12x12 Jacobian matrix, represented by a numpy array, and then returns this Jacobian. However, when I attempt to decorate said function with @numba.njit, I get the following error:

This is not usually a problem with Numba itself but instead often caused by the use of unsupported features or an issue in resolving types.

As a basic example of my usage, the following code tries to declare a 12x12 numpy zero matrix, but it fails:

import numpy as np
import numba

@numba.njit
def numpy_matrix_test():
    A = np.zeros([12,12])
    return A

A_out = numpy_matrix_test()
print(A_out)

Since I assumed declaring numpy arrays in such a way was common enough that numba would be able to handle them, I'm quite surprised.

asked Dec 23 '22 22:12 by indigoblue


1 Answer

The assumption that a function called inside a numba-jitted function is the same function you would call outside of it is understandable, but wrong. Behind the scenes, numba delegates to its own implementations instead of using the "real" NumPy functions.

So it's not really np.zeros that is called in the jitted function, it's numba's own version of it. Some differences between Numba and NumPy are therefore unavoidable.

For example, you cannot use a list for the shape; it has to be a tuple (lists and arrays produce the exception you've encountered). So the correct syntax would be:

@numba.njit
def numpy_matrix_test():
    A = np.zeros((12, 12))
    return A

Something similar applies to the dtype argument. It has to be a real NumPy/numba type; a Python type cannot be used:

@numba.njit
def numpy_matrix_test():
    A = np.zeros((12, 12), dtype=int)  # fails: to make it work, use numba.int64 instead of int here
    return A

This holds even though "plain" NumPy allows it:

np.zeros((12, 12), dtype=int)
answered Dec 25 '22 12:12 by MSeifert