Numba dictionary: signature in JIT() decorator

My function takes a list of numpy arrays and a dictionary (or a list of dictionaries) as input arguments and returns a list of values. The list of numpy arrays is long, and the arrays may have different shapes. Though I could pass the numpy arrays separately, for housekeeping purposes I would really like to form a tuple of numpy arrays and pass that into my function. Without the dictionary (a typed dictionary, as supported by numba >= 0.43) the whole setup works fine; see the script below. Because the input and output are tuples, JIT requires a signature: it cannot figure out the type of the data structure without one. However, no matter how I try to declare my dictionary 'd' in the JIT decorator, I cannot get the script working. Please help with ideas or a solution, if one exists.

Many thanks

```python
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

@njit(  'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ))'  )

def somefunction(lst_arr):
    arr1, arr2 = lst_arr

    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ,prod)
    return result

a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)
print(a,b)

print(somefunction(arg))


# ~~ The Dict.empty() constructs a typed dictionary.
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,)

d['k1'] = 1.5
d['k2'] = 0.5
```

I want to pass the dictionary 'd' into 'somefunction' and use its keys inside the function, for example: result = (summ * d['k1'], prod * d['k2']). This is what I tried:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

@njit(  'Tuple( (f8,f8) )(Tuple( (f8[:],f8[:]) ), Dict)'  )

def somefunction(lst_arr, mydict):
    arr1, arr2 = lst_arr

    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ*mydict['k1'],prod*mydict['k2'])
    return result

# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(5)+11.0
arg = (a,b)

# ~~ Input dictionary for the function 
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64)

d['k1'] = 1.5
d['k2'] = 0.5


# ~~ Run function and print results
print(somefunction(arg, d))
Yevgeniy Karpekin asked Nov 06 '22 13:11


1 Answer

I am using version 0.45.1. You can simply pass the dictionary without declaring its type in the decorator signature:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)
d['k1'] = np.arange(5) + 1.0
d['k2'] = np.arange(5) + 11.0

# Numba will infer the type on its own.
@njit
def somefunction2(d):
    # I am assuming you want the sum of the second array and the product of the first
    result = (d['k2'].sum(), d['k1'].prod())

    return result

print(somefunction2(d))
# Output : (65.0, 120.0)

For reference, see this example in the official documentation.

Update:
In your case you can simply let jit infer the types on its own, and it should work. The following code works for me:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Let jit infer the types on its own
@njit
def somefunction(lst_arr, mydict):
    arr1, arr2 = lst_arr
    summ = 0
    prod = 1
    for i in arr2:
        summ += i
    for j in arr1:
        prod *= j

    result = (summ*mydict['k1'],prod*mydict['k2'])
    return result

# ~~ Input numpy arrays
a = np.arange(5)+1.0
b = np.arange(10)+11.0  #<--------------- This is of different shape 
arg = (a,b)

# ~~ Input dictionary for the function 
d = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64)

d['k1'] = 1.5
d['k2'] = 0.5


# This works now
print(somefunction(arg, d))

The official documentation says:

Unless necessary, it is recommended to let Numba infer argument types by using the signature-less variant of @jit.

I tried various methods, but this is the only one that worked for the problem you specified.

Gambit1614 answered Nov 15 '22 06:11