 

How to return boolean array in numba.njit?

import numpy as np
from numba import njit, float64
from numba.experimental import jitclass

@njit(fastmath=True)
def compare(values1, values2):
    shape = values1.shape[0]
    res = np.zeros(shape, dtype=bool)
    
    for i in range(shape):
        res[i] = values1[i] > values2[i]
    
    return res

spec = [("x", float64[:]),
        ("y", float64[:]),
        ("z", float64[:]),]

@jitclass(spec)
class Math:
    
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
    
    def calculate(self):
        i = compare(self.x, self.y)
        return self.z[i]

If I test it like this:

x = np.random.rand(10)
y = np.random.rand(10)
compare(x, y)

I get this error:

Traceback (most recent call last):

  File "<ipython-input-25-586dc5d173c7>", line 3, in <module>
    compare(x, y)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)

TypingError: No implementation of function Function(<built-in function zeros>) found for signature:
 
zeros(int64, dtype=Function(<class 'bool'>))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 504.
    With argument(s): '(int64, dtype=Function(<class 'bool'>))':
   No match.

During: resolving callee type: Function(<built-in function zeros>)
During: typing of call at <ipython-input-24-69a4f907fb89> (4)

Ultimately, I need to use it inside the jitclass:

x = np.random.rand(10)
y = np.random.rand(10)
z = np.random.rand(10)

m = Math(x, y, z)
m.calculate()

In plain NumPy the result is just z[x > y], but how can I do this with njit and jitclass?

I need both of them to speed up the rest of my code.

If compare could return a boolean array, the problem would be solved.

Asked Jun 23 '26 at 12:06 by onon99buynoodle


1 Answer

You have to use Numba's special bool_ type for that:

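import numpy as np
from numba import njit
from numba.types import bool_, int_, float32

# Eager compilation: two 3-D float32 arrays and an int in, a 2-D boolean array out
@njit(bool_[:,:](float32[:,:,:],float32[:,:,:],int_))
def test(im1, im2, j_delta=1):
    # sum of squared differences over the last axis, divided by 3
    # (j_delta is unused in this snippet)
    diff = ((im1 - im2)**2).sum(axis=2)/3
    mask = np.zeros_like(diff, bool_)  # <--- like so
    for i in range(diff.shape[0]):
        for j in range(diff.shape[1]):
            mask[i,j] = diff[i,j] > 1.0
    return mask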

If you replace bool_ with bool or even np.bool, you'll get a compile error.

Answered Jun 25 '26 at 01:06 by Samuel Prevost


