 

numba TypingError with complex numpy array and native data types

I have a function that works with complex data, and I am using numba to speed it up. Inside the function I create a zero array with numpy, using a complex data type, which is filled in later. However, when the function runs, numba cannot find a matching implementation of np.zeros for these arguments. Here is an MWE that reproduces the error:

import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex)
    a[idx] = 10
    return a

my_func(4)

The following error is raised at the line where the array a is initialized:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)

No implementation of function Function(<built-in function zeros>) found for signature:
zeros(Tuple(Literal[int](10), Literal[int](5)), dtype=Function(<class 'complex'>))
There are 2 candidate implementations:

 Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 511.
    With argument(s): '(UniTuple(int64 x 2), dtype=Function(<class 'complex'>))':
   No match.

I assume this has to do with the data type of the variable a (I need it to be complex). How can I fix this error?

Any help would be appreciated, thanks.

asked Nov 07 '22 01:11 by learner


1 Answer

Your problem has nothing to do with complex numbers. If you specified a = np.zeros((10, 5), dtype=int), you'd run into the same problem.

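NumPy accepts the Python builtin types int, float and complex as dtypes and maps them to its own types: complex becomes np.complex128, float becomes np.float64, and int becomes the platform's default integer np.int_ (int64 on most Linux and macOS systems, int32 on Windows before NumPy 2.0). Numba, however, does not perform this translation on its own.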
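To see the mapping NumPy applies, you can ask np.dtype to resolve each builtin. This is a minimal pure-NumPy sketch (no numba involved); the names resolved, a and b are just for illustration:

```python
import numpy as np

# Ask NumPy which dtype each Python builtin resolves to.
resolved = {t.__name__: np.dtype(t) for t in (int, float, complex)}
for name, dt in resolved.items():
    print(f"{name:>7} -> {dt}")

# float and complex resolve to fixed-width types on every platform...
assert resolved["float"] == np.float64
assert resolved["complex"] == np.complex128
# ...while int resolves to the platform's default integer, np.int_.
assert resolved["int"] == np.dtype(np.int_)

# In plain NumPy, these two calls therefore build arrays of the same dtype.
a = np.zeros((10, 5), dtype=complex)
b = np.zeros((10, 5), dtype=np.complex128)
assert a.dtype == b.dtype
```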

So whenever you specify data types inside jitted functions, you either use numpy data types:

import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=np.complex128)
    a[idx] = 10
    return a

my_func(4)

Or you use numba's own data types, either imported directly:

import numpy as np
from numba import njit, complex128

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex128)
    a[idx] = 10
    return a

my_func(4)

or via the types module:

import numpy as np
from numba import njit, types

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=types.complex128)
    a[idx] = 10
    return a

my_func(4)

To my knowledge, it makes no difference which of these options you use. Here is the relevant part of the numba documentation.

answered Nov 14 '22 21:11 by Nic Moetsch