
Numba Compatible Memoization

I have just discovered numba, and learnt that optimal performance requires adding @njit to most functions, such that numba rarely exits LLVM mode.

I still have a few expensive/lookup functions that could benefit from memoization, but so far none of my attempts have found a workable solution that compiles without error.

  • Using common decorator functions before (above) @njit results in numba not being able to do type inference
  • Using decorators after (below) @njit fails, because numba cannot compile the decorator
  • Numba doesn't like the use of global variables, even when using numba.typed.Dict
  • Numba doesn't like using closures to store mutable state
  • Removing @njit also causes type errors when called from other @njit functions

What is the correct way to add memoization to functions when working inside numba?

import functools
import time

import fastcache
import numba
import numpy as np
import toolz
from numba import njit

from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize

# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks

# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks

# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache

# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]

# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit
def _expensive_wrapped():
    cache = []
    def wrapper():
        if len(cache) == 0:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()

@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global()
        # expensive_wrapped()

def main():
    time_start = time.perf_counter()

    count = 10000
    loop(count)

    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')

loop(1)  # precache numba
main()

# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s
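For contrast, a sketch of the pure-Python side (`expensive_python` and `loop_python` are illustrative names): with no @njit anywhere in the call chain, `functools.lru_cache` memoizes the same computation normally, which fits the pattern above that the failures start only once numba has to type or compile the cached call.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=None)
def expensive_python():
    # Body runs only on the first call; later calls return the cached array.
    bitmasks = np.array([1 << n for n in range(64)], dtype=np.uint64)
    bitmasks.setflags(write=False)  # every caller shares this one object
    return bitmasks


def loop_python(count):
    total = 0
    for _ in range(count):
        total += int(expensive_python()[3])  # bitmasks[3] == 8
    return total
```

After `loop_python(10)` (which returns 80), `expensive_python.cache_info()` reports one miss and nine hits. Marking the array read-only is a precaution, since `lru_cache` hands the same mutable object to every caller.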
asked Jul 27 '20 at 02:07 by James McGuigan


1 Answer

You already mentioned that your real code is more complex, but looking at your minimal example, I would recommend the following pattern:

def loop(count):
    expensive_result = expensive()  # compute once, before the loop
    for i in range(count):
        do_something(count, expensive_result)  # reuse the precomputed result

Rather than using a cache, you could precompute the value outside the loop and pass the result to the loop body. And rather than using globals, I would recommend passing every argument explicitly (always, but especially when using numba's jit).

answered Oct 18 '22 at 04:10 by Carlos Horn
