
Numba: calling jit with explicit signature using arguments with default values

I'm using numba to compile some functions containing loops over numpy arrays.

Everything is fine and dandy, I can use jit and I learned how to define the signature.

Now I tried using jit on a function with optional arguments, e.g.:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b

This works, but if I use optional(float64) instead of optional(float), it doesn't (the same happens with int versus int64). I lost an hour trying to figure this syntax out (actually, a friend of mine found this solution by chance because he forgot to write the 64 after float), but for the life of me I cannot understand why. I can't find anything on the internet, and numba's docs on the topic are scarce at best (and they specify that optional should take a numba type).

Does anyone know how this works? What am I missing?

asked Sep 08 '17 by gionni




1 Answer

Ah, but the exception message should give a hint:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)

That means optional is the wrong choice here. In fact, optional represents "None or that type". But you want an optional argument (one that can be left out), not an argument that could be either a float or None, e.g.:

>>> fun(10, None)  # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

I suspect that it just "happens" to work for optional(float) because float is just an "arbitrary Python object" from numba's point of view, so with optional(float) you could pass anything in there (this apparently includes not giving the argument at all). With optional(float64) it can only be None or a float64, and that category isn't broad enough to allow leaving the argument out.

It works if you give the type Omitted:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0

However, it seems that Omitted isn't actually included in the documentation and that it has some "rough edges". For example, it can't be compiled in nopython mode with that signature, even though nopython mode seems to work without a signature:

from numba import njit

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0
answered Nov 14 '22 by MSeifert