 

Compare strings in numba-compiled function

I'm looking for the best way to compare strings in a Python function compiled with numba jit (nopython mode, Python 3).

The use case is the following:

import numba as nb

@nb.jit(nopython=True, cache=True)
def foo(a, t='default'):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        ...

However, the following error is returned:

Invalid usage of == with parameters (str, const('awesome'))

I tried using bytes but didn't succeed.

Thanks!


Maurice pointed out the question "Python: can numba work with arrays of strings in nopython mode?", but I'm looking at native Python strings, not the NumPy subset supported in numba.

asked Oct 12 '17 11:10 by Thomas C.



2 Answers

For newer numba versions (0.41.0 and later):

Numba (since version 0.41.0) supports str in nopython mode, and the code as written in the question will "just work". However, in your example, comparing the strings is much slower than the arithmetic itself, so if you want to use strings in numba functions, make sure the overhead is worth it.

import numba as nb

@nb.njit
def foo_string(a, t):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        return a

@nb.njit
def foo_int(a, t):
    if t == 1:
        return(a**2)
    elif t == 0:
        return(a**3)
    else:
        return a

assert foo_string(100, 'default') == foo_int(100, 0)
%timeit foo_string(100, 'default')
# 2.82 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_int(100, 0)
# 213 ns ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In this case the string version is more than 10 times slower (2.82 µs vs. 213 ns).

Since your function doesn't do much, it can be better and faster to do the string comparison in plain Python and pass an integer code to the numba function:

def foo_string2(a, t):
    if t == 'awesome':
        sec = 1
    elif t == 'default':
        sec = 0
    else:
        sec = -1
    return foo_int(a, sec)

assert foo_string2(100, 'default') == foo_string(100, 'default')
%timeit foo_string2(100, 'default')
# 323 ns ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

This is still a bit slower than the pure integer version, but it's almost 9 times faster than using the string inside the numba function (323 ns vs. 2.82 µs).
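The if/elif chain in foo_string2 can also be written as a dictionary lookup, which keeps the string-to-code mapping in one place as more options are added. This is a minimal sketch; T_CODES and t_code are hypothetical names, and I haven't timed it against the if/elif version:

```python
# Hypothetical lookup table mapping the accepted strings to the integer
# codes that foo_int understands; any other string maps to -1.
T_CODES = {'awesome': 1, 'default': 0}


def t_code(t):
    """Translate the string option to foo_int's integer code in plain Python."""
    return T_CODES.get(t, -1)


print(t_code('awesome'))  # 1
print(t_code('default'))  # 0
print(t_code('other'))    # -1
```

With this helper, foo_string2(a, t) reduces to foo_int(a, t_code(t)), so the string handling stays entirely outside numba.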

However, if you do a lot of numerical work in the numba function, the string comparison overhead won't matter. Still, simply putting numba.njit on a function, especially one that doesn't do much array work or number crunching, won't automatically make it faster!

For older numba versions (before 0.41.0):

Numba doesn't support strings in nopython mode.

From the documentation:

2.6.2. Built-in types

2.6.2.1. int, bool [...]

2.6.2.2. float, complex [...]

2.6.2.3. tuple [...]

2.6.2.4. list [...]

2.6.2.5. set [...]

2.6.2.7. bytes, bytearray, memoryview

The bytearray type and, on Python 3, the bytes type support indexing, iteration and retrieving the len().

[...]

So strings aren't supported at all, and bytes don't support equality checks.

However, you can pass in bytes and iterate over them, which makes it possible to write your own comparison function:

import numba as nb

@nb.njit
def bytes_equal(a, b):
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True

Unfortunately, the next problem is that numba cannot "lower" bytes, so you cannot hardcode bytes constants in the function directly. But bytes are basically just integers, and bytes_equal works for any numba-supported type that has a length and can be iterated over. So you could simply store the character codes as lists:

import numba as nb

@nb.njit
def foo(a, t):
    if bytes_equal(t, [97, 119, 101, 115, 111, 109, 101]):
        return a**2
    elif bytes_equal(t, [100, 101, 102, 97, 117, 108, 116]):
        return a**3
    else:
        return a

or as global arrays (thanks @chrisb - see comments):

import numba as nb
import numpy as np

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')

@nb.njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a

Both will work correctly:

>>> foo(10, b'default')
1000
>>> foo(10, b'awesome')
100
>>> foo(10, b'awe')
10
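If you prefer the list-based variant but don't want to type the character codes by hand, they can be generated in plain Python, since iterating over a bytes object yields its byte values as ints. A small sketch; AWESOME_CODES and DEFAULT_CODES are hypothetical names. I use tuples on the assumption that numba accepts global tuples as compile-time constants while rejecting global (reflected) lists, so verify this on your numba version:

```python
# Iterating over bytes yields ints, so tuple() turns an ASCII string's
# encoding into its character codes.
# Assumption: tuples rather than lists, because numba freezes global
# tuples as constants but does not accept global lists.
AWESOME_CODES = tuple('awesome'.encode('ascii'))
DEFAULT_CODES = tuple('default'.encode('ascii'))

print(AWESOME_CODES)  # (97, 119, 101, 115, 111, 109, 101)
print(DEFAULT_CODES)  # (100, 101, 102, 97, 117, 108, 116)
```

These are the same codes as the list literals hardcoded in the first foo, so bytes_equal(t, AWESOME_CODES) should behave like the list version.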

However, you cannot use a bytes object as a default argument, so t always has to be passed explicitly. It also feels hacky to do it that way.
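One way to get both the default and a str interface back is a thin Python wrapper that encodes the string and then calls the jitted function. This is a sketch assuming ASCII option names; the wrapper name foo_str is hypothetical, and the try/except fallback is only there so the snippet also runs as plain Python when numba isn't installed:

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:  # assumption: fall back to plain Python without numba
    def njit(func):
        return func

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')


@njit
def bytes_equal(a, b):
    # Element-wise comparison of two sized, iterable sequences.
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True


@njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a


def foo_str(a, t='default'):
    # Plain-Python wrapper: supplies the default as a str and encodes it
    # to bytes (ASCII only) before calling the jitted foo.
    return foo(a, t.encode('ascii'))
```

Callers can then write foo_str(10) or foo_str(10, 'awesome'), and only the wrapper deals with the bytes encoding.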

My opinion: just do the if t == ... checks in a normal Python function and call specialized numba functions inside the branches. String comparisons are really fast in Python, so only wrap the math- or array-intensive parts in numba functions:

import numba as nb

@nb.njit
def awesome_func(a):
    return a**2

@nb.njit
def default_func(a):
    return a**3

@nb.njit
def other_func(a):
    return a

def foo(a, t='default'):
    if t == 'awesome':
        return awesome_func(a)
    elif t == 'default':
        return default_func(a)
    else:
        return other_func(a)

But make sure you actually need numba for the functions. Sometimes normal Python/NumPy will be fast enough. Just profile the numba solution and a Python/NumPy solution and see if numba makes it significantly faster. :)

answered Oct 22 '22 02:10 by MSeifert



I'd suggest accepting @MSeifert's answer, but as another option for these types of problems, consider using an enum.

In Python, strings are often used as a sort of enum, and numba has built-in support for enums, so they can be used directly.

import enum

import numba

class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2

@numba.njit
def foo(a, t=FooOptions.DEFAULT):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**3  # matches the question: default -> a**3
    else:
        return a

>>> foo(10, FooOptions.AWESOME)
100
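If callers should still be able to pass strings, the conversion to an enum member can happen in plain Python before calling the jitted foo. A minimal sketch, assuming option names that match the member names case-insensitively; the helper name to_option is hypothetical:

```python
import enum


class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2


def to_option(name):
    """Map a user-facing string such as 'awesome' to a FooOptions member.

    Enum lookup by member name raises KeyError for unknown options.
    """
    return FooOptions[name.upper()]


print(to_option('awesome'))  # FooOptions.AWESOME
print(to_option('default'))  # FooOptions.DEFAULT
```

foo(10, to_option('awesome')) then passes an enum member into the numba function, and unknown strings fail early with a KeyError instead of silently falling through to the else branch.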
answered Oct 22 '22 02:10 by chrisb
