
Array of ints in numba

I am calculating the most frequent number in a vector of int8s. Numba complains when I set up a counter array of ints:

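import numpy as np
from numba import jit

@jit(nopython=True)
def freq_int8(y):
    """Find most frequent number in array"""
    count = np.zeros(256, dtype=int)  # this line triggers the TypingError below
    for val in y:
        count[val] += 1  # negative values wrap around to indices 128..255
    # map the winning index back to the signed int8 range
    return ((np.argmax(count)+128) % 256) - 128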

Calling it, I get the following error:

TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))

If I delete dtype=int, it works and I get a decent speedup. I am, however, puzzled as to why declaring an array of ints doesn't work. Is there a known workaround, and would an integer counter array give a worthwhile efficiency gain over the default float64 one?

Background: I am trying to shave microseconds off some NumPy-heavy code. numpy.median is hurting me in particular, and I have been looking into Numba, but I am struggling to beat median. Finding the most frequent number is an acceptable alternative to the median for my purposes, and there I have been able to gain some performance. The numba code above is also faster than numpy.bincount.

Update: Based on the accepted answer, here's an implementation of the median for int8 vectors. It is roughly an order of magnitude faster than numpy.median:

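import numpy as np
from numba import jit

@jit(nopython=True)
def median_int8(y):
    if len(y) == 0:
        return np.nan  # median of an empty vector is undefined
    N2 = len(y)//2
    count = np.zeros(256, dtype=np.int32)
    for val in y:
        count[val] += 1  # negative values wrap around to indices 128..255
    cs = 0
    for i in range(-128, 128):
        cs += count[i]  # cumulative count of elements <= i
        if cs > N2:
            return float(i)
        elif cs == N2 and len(y) % 2 == 0:
            # even length: average i with the next occupied bin j
            j = i+1
            while count[j] == 0:
                j += 1
            return (i + j)/2
    return np.nan  # not reached for non-empty input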
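The same counting idea can be written with vectorized NumPy, which is handy for cross-checking median_int8 against np.median on a machine without numba. This is a sketch assuming int8 input; the name counting_median_int8 is hypothetical, and it shifts values by +128 instead of relying on negative-index wraparound.

```python
import numpy as np

def counting_median_int8(y):
    """Median of int8 data via a 256-bin histogram (hypothetical helper)."""
    y = np.asarray(y, dtype=np.int8)
    n = len(y)
    if n == 0:
        return float("nan")
    # Shift values from [-128, 127] to bin indices [0, 255].
    counts = np.bincount(y.astype(np.int16) + 128, minlength=256)
    cum = np.cumsum(counts)  # cum[b] = number of elements in bins 0..b

    def value_at(k):
        # The element at sorted position k lives in the first bin with cum > k.
        return int(np.searchsorted(cum, k, side="right")) - 128

    if n % 2:
        return float(value_at(n // 2))
    return (value_at(n // 2 - 1) + value_at(n // 2)) / 2
```

For example, counting_median_int8([1, 2, 3]) gives 2.0 and counting_median_int8([-5, 3]) gives -1.0, matching np.median.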

Surprisingly, the performance difference is even greater for short vectors, apparently because of the fixed per-call overhead of numpy.median:

>>> a = np.random.randint(-128, 128, 10)

>>> %timeit np.median(a)
    The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 20.8 µs per loop

>>> %timeit median_int8(a)
    The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
    1000000 loops, best of 3: 593 ns per loop

This overhead is so large that I'm wondering whether something is wrong.

asked May 18 '16 at 10:05 by DNF


1 Answer

Just a quick note: finding the most frequent number is normally called the mode, and the mode is no closer to the median than it is to the mean. If either were an acceptable stand-in, np.mean would be considerably faster. Unless your data has some constraints or particular properties, there is no guarantee that the mode approximates the median.
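A tiny hand-made example shows how far apart the mode and the median can be. The data below is made up purely for illustration: a repeated outlier wins the count while the median stays in the middle of the data.

```python
import numpy as np

# Made-up int8 sample: a repeated outlier dominates the counts.
y = np.array([-100, -100, 5, 6, 7], dtype=np.int8)

values, counts = np.unique(y, return_counts=True)
mode = int(values[np.argmax(counts)])  # most frequent value
median = float(np.median(y))           # middle value of the sorted data
mean = float(np.mean(y))               # arithmetic mean

print(mode, median, mean)  # -100 5.0 -36.4
```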

If you still want to calculate the mode of a list of integer numbers, np.bincount, as you mention, should be enough (if numba is faster, it shouldn't be by much):

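import numpy as np
# bincount rejects negative values, so reinterpret the int8 bytes as uint8 (-1 -> 255, ...)
count = np.bincount(y.view(np.uint8), minlength=256)
result = ((np.argmax(count)+128) % 256) - 128  # map the bin index back to a signed int8 value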

Note that I've added the minlength parameter to np.bincount only so that it returns the same 256-element array as your code. In practice it is unnecessary: since you only want the argmax, np.bincount without minlength returns an array whose length is the largest value passed to it plus one, which is enough.
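Here is a sketch of the bincount approach without minlength, assuming the input is int8 data. The helper name mode_int8 is hypothetical. Because np.bincount only accepts non-negative integers, the bytes are reinterpreted as uint8 first, and the same modular formula maps the winning bin back to a signed value.

```python
import numpy as np

def mode_int8(y):
    """Mode of int8 data using np.bincount (hypothetical helper)."""
    # np.bincount rejects negative inputs, so reinterpret the bytes as uint8:
    # 0..127 stay the same, -128..-1 become 128..255.
    u = np.asarray(y, dtype=np.int8).view(np.uint8)
    count = np.bincount(u)  # no minlength: length is u.max() + 1
    # Map the winning bin back to the signed int8 range.
    return int(((np.argmax(count) + 128) % 256) - 128)
```

On ties, np.argmax returns the first maximal bin, so in this uint8 ordering non-negative values win ties over negative ones.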

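As for the numba error, replacing dtype=int with dtype=np.int32 (or np.int64) should solve the problem. int is the Python built-in type rather than a NumPy dtype, and the numba version you are using does not accept it as the dtype of np.zeros in nopython mode, which you request in the decorator. If you remove nopython=True, then either dtype=int or dtype='i' will also work, because numba falls back to the generally slower object mode. Note that the two are not identical: 'i' is a 32-bit integer, while int maps to NumPy's platform-dependent default integer. Recent Numba releases make nopython mode the default, so this fallback is no longer available there.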
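A sketch of the corrected function, using a NumPy dtype object instead of the built-in int. It uses numba's njit decorator (equivalent to jit(nopython=True)); the try/except fallback is an assumption added only so the snippet still runs, slowly and as plain Python, where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python if numba is missing
    def njit(func):
        return func

@njit
def freq_int8(y):
    """Most frequent value in an int8 array."""
    # A NumPy dtype object (not the Python built-in int) as the counter dtype.
    count = np.zeros(256, dtype=np.int64)
    for val in y:
        # Negative int8 values index from the end: -1 -> count[255], -128 -> count[128].
        count[val] += 1
    # Map the bin index back to the signed range: 0..127 stay, 128..255 become -128..-1.
    return ((np.argmax(count) + 128) % 256) - 128
```

For example, freq_int8(np.array([-3, -3, 5, 7, -3, 5], dtype=np.int8)) returns -3.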

answered Nov 13 '22 at 01:11 by Imanol Luengo