A follow-up to this question: function types in numba.
I'm writing a function that needs to take a generator as one of its arguments. It's too complicated to paste in here, so consider this toy example:
import numba

def take_and_sum(gen):
    @numba.jit(nopython=False)
    def inner(n):
        s = 0
        for _ in range(n):
            s += next(gen)
        return s
    return inner
It returns the sum of the first n elements of the generator. Example usage:
@numba.njit()
def odd_numbers():
    n = 1
    while True:
        yield n
        n += 2

take_and_sum(odd_numbers())(3)  # returns 9
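For reference, here is a plain-Python sketch of what the toy example is meant to compute, with no numba involved. The names `odd_numbers_py` and `take_and_sum_py` are made up for this illustration and are not part of the real code.

```python
from itertools import islice

def odd_numbers_py():
    # Hypothetical pure-Python twin of odd_numbers(): yields 1, 3, 5, ...
    n = 1
    while True:
        yield n
        n += 2

def take_and_sum_py(gen):
    # Hypothetical pure-Python twin of take_and_sum(): same currying,
    # but the inner function is ordinary Python, not jitted.
    def inner(n):
        # islice pulls exactly n items from gen, advancing its state.
        return sum(islice(gen, n))
    return inner

result = take_and_sum_py(odd_numbers_py())(3)  # 1 + 3 + 5
```

This is only a baseline to compare the jitted versions against, not a fix for the nopython problem.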
It's curried because I would like to compile with nopython=True, and then I can't pass gen (a pyobject) as an argument. Unfortunately, with nopython=True I get an error:
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'gen'
even though I compiled my generator in nopython mode.
What's really confusing about this is that hard-coding the input works:
def take_and_sum():
    @numba.njit()
    def inner(n):
        gen = odd_numbers()
        s = 0.0
        for _ in range(n):
            s += next(gen)
        return s
    return inner

take_and_sum()(3)
I also tried turning my generator into a class instead:
@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def next(self):
        n = self.n
        self.n += 2
        return n
Again, this works in object mode, but in nopython mode I get the unsearchable:
LoweringError: Failed at nopython (nopython mode backend)
Internal error:
NotImplementedError: instance.jitclass.Odd#4aa9758<n:uint64> as constant unsupported
I can't actually solve your problem because, as far as I know, it's simply not possible. I'm just highlighting some aspects (valid for numba 0.30):
You can't create a numba jitclass generator:
import numba

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def __iter__(self):
        return self

    def __next__(self):
        n = self.n
        self.n += 2
        return n
Just try:
>>> next(Odd())
TypeError: 'Odd' object is not an iterator
When you remove the numba.jitclass decorator, it works:
>>> next(Odd())
1
Your examples with the hardcoded generator are not equivalent. Your original attempt creates a generator object, passes it to a numba function, and that function modifies the generator. You would expect it to update the state of the generator.
>>> t = odd_numbers()
>>> take_and_sum(t)(3)
9
>>> next(t) # State has been updated, unfortunately that requires nopython=False!
7
But that's simply not possible with numba (yet).
The second example is different because you create the generator each time you call the function, so there is no state outside your function that needs to be updated:
>>> take_and_sum()(3) # using your hardcoded version
9.0
>>> take_and_sum()(3) # no updated state so this returns the same:
9.0
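The difference between the two variants can be seen without numba at all. The following is a minimal pure-Python sketch; `odd_numbers_py`, `sum_shared` and `sum_fresh` are hypothetical names used only for this illustration.

```python
def odd_numbers_py():
    # Hypothetical plain-Python generator: yields 1, 3, 5, ...
    n = 1
    while True:
        yield n
        n += 2

def sum_shared(gen, n):
    # Consumes n items from a generator owned by the caller,
    # so the caller sees the advanced state afterwards.
    s = 0
    for _ in range(n):
        s += next(gen)
    return s

def sum_fresh(n):
    # Creates its own generator on every call, so no state
    # survives between calls.
    gen = odd_numbers_py()
    s = 0
    for _ in range(n):
        s += next(gen)
    return s

shared = odd_numbers_py()
first_shared = sum_shared(shared, 3)   # 1 + 3 + 5
second_shared = sum_shared(shared, 3)  # 7 + 9 + 11
first_fresh = sum_fresh(3)             # 1 + 3 + 5
second_fresh = sum_fresh(3)            # 1 + 3 + 5 again
```

The shared version gives a different result on the second call because the generator lives outside the function. The fresh version repeats itself, which is the behaviour of the hardcoded example.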
It's definitely possible to change it, but without the option to use arbitrary functions:
@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def calculate(self, n):
        s = 0.0
        for _ in range(n):
            s += self.n
            self.n += 2
        return s
>>> x = Odd()
>>> x.calculate(3)
9.0
>>> x.calculate(3)
27.0
I know that's not what you wanted, but at least it somehow works :-)