Workaround for NumPy np.all axis argument; compatibility with Numba

I have a function that, given a NumPy array of xy coordinates, keeps only the points that lie within a box of side L:

import numpy as np
from numba import njit

np.random.seed(65238758)

L = 10
N = 1000
xy = np.random.uniform(0, 50, (N, 2))
box = np.array([
    [0,0],  # lower-left
    [L,L]  # upper-right
]) 

def sinjit(xy, box):
    mask = np.all(np.logical_and(xy >= box[0], xy <= box[1]), axis=1)
    return xy[mask]

If I run this function, it returns the correct result:

sinjit(xy, box)

Output: array([[5.53200522, 7.86890708],
       [4.60188554, 9.15249881],
       [9.072563  , 5.6874726 ],
       [4.48976127, 8.73258166],
       ...
       [6.29683131, 5.34225758],
       [2.68057087, 5.09835442],
       [5.98608603, 4.87845464],
       [2.42049857, 6.34739079],
       [4.28586677, 5.79125413]])

However, I want to speed this up inside a loop using Numba, and the axis argument of np.all is not supported in nopython mode. Is there any way to avoid that argument, or some other workaround?

asked Oct 16 '22 04:10 by Dani


1 Answer

I really, really, really wish numba supported optional keyword arguments. Until it does, I'm pretty much ignoring it. However, some hacks are possible here.

These helpers assume 2-D input, so take extra care with arrays of any other dimensionality or with axes that might have length zero.

import numpy as np
from numba import njit

@njit(cache=True)
def np_all_axis0(x):
    """Numba compatible version of np.all(x, axis=0)."""
    out = np.ones(x.shape[1], dtype=np.bool_)
    for i in range(x.shape[0]):
        out = np.logical_and(out, x[i, :])
    return out

@njit(cache=True)
def np_all_axis1(x):
    """Numba compatible version of np.all(x, axis=1)."""
    out = np.ones(x.shape[0], dtype=np.bool_)
    for i in range(x.shape[1]):
        out = np.logical_and(out, x[:, i])
    return out

@njit(cache=True)
def np_any_axis0(x):
    """Numba compatible version of np.any(x, axis=0)."""
    out = np.zeros(x.shape[1], dtype=np.bool_)
    for i in range(x.shape[0]):
        out = np.logical_or(out, x[i, :])
    return out

@njit(cache=True)
def np_any_axis1(x):
    """Numba compatible version of np.any(x, axis=1)."""
    out = np.zeros(x.shape[0], dtype=np.bool_)
    for i in range(x.shape[1]):
        out = np.logical_or(out, x[:, i])
    return out

if __name__ == '__main__':
    x = np.array([[1, 1, 0, 0], [1, 0, 1, 0]], dtype=bool)
    np.testing.assert_array_equal(np.all(x, axis=0), np_all_axis0(x))
    np.testing.assert_array_equal(np.all(x, axis=1), np_all_axis1(x))
    np.testing.assert_array_equal(np.any(x, axis=0), np_any_axis0(x))
    np.testing.assert_array_equal(np.any(x, axis=1), np_any_axis1(x))
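Applied to the question's box filter, np_all_axis1 can stand in for np.all(..., axis=1) inside a jitted function. This is a sketch, not checked against every Numba release: the name conjit is hypothetical, cache=True is dropped, and the try/except around the import is an added assumption so the snippet still runs as plain NumPy where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to running the functions as plain Python.
    def njit(func):
        return func


@njit
def np_all_axis1(x):
    """Numba compatible version of np.all(x, axis=1) for a 2-D boolean array."""
    out = np.ones(x.shape[0], dtype=np.bool_)
    for i in range(x.shape[1]):
        out = np.logical_and(out, x[:, i])
    return out


@njit
def conjit(xy, box):
    """Hypothetical jitted version of sinjit: rows of xy inside the closed box."""
    # Broadcast every row against the lower-left and upper-right corners.
    inside = np.logical_and(xy >= box[0], xy <= box[1])
    # Row-wise "all" without the unsupported axis keyword.
    return xy[np_all_axis1(inside)]
```

Called as conjit(xy, box) with the question's xy and box, it should return the same rows as sinjit(xy, box).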

I'm not sure how performant this will be, but if you really need to call the function from within a higher-level jitted function, this will let you do it.
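If the temporary boolean arrays turn out to matter, a different technique (not the helpers from this answer) is to write the whole filter as explicit loops, which is the style Numba is designed to compile efficiently. A minimal sketch under the same import-fallback assumption; point_in_box and filter_in_box are hypothetical names, and the two-pass count-then-fill layout is just one way to size the output array once.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to running the functions as plain Python.
    def njit(func):
        return func


@njit
def point_in_box(xy, i, box):
    """True if row i of xy lies in the closed box [box[0], box[1]] in every dimension."""
    for j in range(xy.shape[1]):
        # Written as "not (lo <= v <= hi)" so NaN coordinates are rejected, like the mask version.
        if not (xy[i, j] >= box[0, j] and xy[i, j] <= box[1, j]):
            return False
    return True


@njit
def filter_in_box(xy, box):
    """Loop-based equivalent of xy[np.all((xy >= box[0]) & (xy <= box[1]), axis=1)]."""
    # First pass: count matching rows so the output is allocated only once.
    count = 0
    for i in range(xy.shape[0]):
        if point_in_box(xy, i, box):
            count += 1
    # Second pass: copy the matching rows, preserving their original order.
    out = np.empty((count, xy.shape[1]), xy.dtype)
    k = 0
    for i in range(xy.shape[0]):
        if point_in_box(xy, i, box):
            out[k, :] = xy[i, :]
            k += 1
    return out
```

Because filter_in_box is itself jitted, it can be called from a larger @njit function, for example one that loops over many boxes. Whether it beats the np_all_axis1 version depends on the data, so it is worth timing both.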

answered Oct 18 '22 13:10 by DStauffman