 

python numba list of list of tuples

Tags: python, numba

I'm trying to speed up my algorithm with numba, since there were no further improvements left to gain from numpy and general optimization.

I have a function that does some computations in a large, doubly nested loop:

import random
from numba import njit

@njit()
def decide_if_vaild():
    return bool(random.getrandbits(1))

@njit()
def decide_what_bin(bins):
    return random.randint(0, bins-1)

@njit()
def foo(bins, loops):
    results = [[] for _ in range(bins)]

    for i in range(loops):
        for j in range(i+1, loops):
            happy = decide_if_vaild()
            bin = decide_what_bin(bins)
            if happy:
                results[bin].append( (i,j) )
                # or
                # results[bin].append( [i,j] )
    return results

if __name__ == '__main__':
    x = foo(3,100)

When I run the minimal example above, I get (as expected) a typing error:

  File "C:\Users\xxx\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\typeinfer.py", line 104, in getone
    assert self.type is not None
numba.errors.InternalError: 
[1] During: typing of call at C:/Users/xxx/minimal_example.py (21)
--%<-----------------------------------------------------------------

File "minimal_example.py", line 21

The problem is the line "results[bin].append( (i,j) )", where I try to append a tuple to one of the inner lists (appending a list instead does not work either).

The number of bins is known in advance, but how many elements (2-tuples, lists or np.arrays) end up in each bin depends on how often decide_if_vaild evaluates to True. Since I can't predict that, and the computation is quite expensive, I don't know of any other workaround.

Any ideas on how I could build the result inside the jitted function and return it, or pass in a global container that the function could fill?

This probably comes down to:

numba.errors.LoweringError: Failed at nopython (nopython mode backend)
list(list(list(int64))): unsupported nested memory-managed object

where a similar issue for list(list(int64)) (https://github.com/numba/numba/issues/2560) was resolved in numba 0.39.0 by https://github.com/numba/numba/pull/2840
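
One way to sidestep nested memory-managed lists entirely is the "pass in a container to fill" idea from the question. Here is a minimal sketch, assuming the output only needs the accepted (i, j) pairs and their bins. The number of i < j pairs is at most loops*(loops-1)/2, so the caller can preallocate NumPy arrays of that size, let the jitted loop write accepted pairs and their bins into them, and trim to the count it returns. The names is_valid, which_bin, fill_pairs and foo_prealloc are hypothetical. is_valid and which_bin are deterministic stand-ins for decide_if_vaild and decide_what_bin so the output is reproducible, and the numba import falls back to plain Python if numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without numba: njit becomes a no-op decorator
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit
def is_valid(i, j):
    # Hypothetical deterministic stand-in for the expensive decide_if_vaild()
    return (i + j) % 3 == 0


@njit
def which_bin(i, j, bins):
    # Hypothetical deterministic stand-in for decide_what_bin()
    return (i * j) % bins


@njit
def fill_pairs(loops, bins, pairs, bin_of):
    """Write accepted (i, j) pairs and their bins into caller-supplied arrays.

    Returns the number of rows actually written.
    """
    n = 0
    for i in range(loops):
        for j in range(i + 1, loops):
            if is_valid(i, j):
                pairs[n, 0] = i
                pairs[n, 1] = j
                bin_of[n] = which_bin(i, j, bins)
                n += 1
    return n


def foo_prealloc(bins, loops):
    # Upper bound: every i < j pair is accepted
    max_pairs = loops * (loops - 1) // 2
    pairs = np.empty((max_pairs, 2), dtype=np.int64)
    bin_of = np.empty(max_pairs, dtype=np.int64)
    n = fill_pairs(loops, bins, pairs, bin_of)
    # Trim the unused tail of the preallocated arrays
    return pairs[:n], bin_of[:n]
```

The preallocated arrays cost O(loops²) memory in the worst case. If that is too much, a count-then-fill variant (run the loop once to count, allocate exactly, then run it again) avoids the overallocation. However, it doubles the expensive computation and only gives consistent results if the accept/bin decisions are deterministic, which is not the case with the random helpers in the example above.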

asked Jan 30 '26 08:01 by iR0Nic



1 Answer

I have now implemented the following workaround. Even though it does not answer the question completely, it might be a suitable approach for others struggling with this issue:

@njit()
def foo(bins, loops):
    # decide_if_vaild() and decide_what_bin() are the helpers from the question
    results = []   # flat list of (i, j) tuples
    mapping = []   # mapping[k] is the bin that results[k] belongs to

    for i in range(loops):
        for j in range(i+1, loops):
            happy = decide_if_vaild()
            bin = decide_what_bin(bins)
            if happy:
                results.append( (i,j) )
                mapping.append( bin )
    return results, mapping

This returns a list of tuples (supported as of numba 0.39.0) and a list of bin indices, mapping, where mapping[i] holds the bin for results[i]. The jit compiler now handles this without errors, and I can unpack the results into per-bin lists outside of the jitted function.
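
Unpacking outside of jit can be as simple as regrouping the flat results by their bin. Here is a minimal sketch, assuming results and mapping are the two lists returned by the jitted foo above. unpack and unpack_arrays are hypothetical helper names: the first rebuilds the list-of-lists-of-tuples shape the question originally wanted, and the second produces one (n, 2) NumPy array per bin.

```python
import numpy as np


def unpack(results, mapping, bins):
    """Group the flat (results, mapping) output into one list of (i, j) tuples per bin."""
    per_bin = [[] for _ in range(bins)]
    for pair, b in zip(results, mapping):
        per_bin[b].append(pair)
    return per_bin


def unpack_arrays(results, mapping, bins):
    """Same grouping, but returns one int64 array of shape (n_b, 2) per bin."""
    # reshape(-1, 2) also turns an empty result list into a (0, 2) array
    pairs = np.asarray(results, dtype=np.int64).reshape(-1, 2)
    mapping = np.asarray(mapping, dtype=np.int64)
    # One boolean mask per bin selects the rows that belong to it
    return [pairs[mapping == b] for b in range(bins)]
```

unpack_arrays makes one pass over mapping per bin. With many bins, sorting once by bin (for example with np.argsort(mapping, kind="stable")) and splitting the sorted array would likely scale better.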

answered Jan 31 '26 21:01 by iR0Nic
