I have a function that I want to compile with numba; however, I need to calculate a factorial inside that function. Unfortunately, numba doesn't support math.factorial:
import math
import numba as nb

@nb.njit
def factorial1(x):
    return math.factorial(x)

factorial1(10)
# UntypedAttributeError: Failed at nopython (nopython frontend)
I saw that it supports math.gamma (which could be used to calculate the factorial); however, contrary to the real math.gamma function, it doesn't return floats that represent "integral values":
@nb.njit
def factorial2(x):
    return math.gamma(x + 1)

factorial2(10)
# 3628799.9999999995 <-- not exact
math.gamma(11)
# 3628800.0 <-- exact
and it's slow compared to math.factorial:
%timeit factorial2(10)
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit math.factorial(10)
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
So I decided to define my own function:
@nb.njit
def factorial3(x):
    n = 1
    for i in range(2, x + 1):
        n *= i
    return n

factorial3(10)
# 3628800
%timeit factorial3(10)
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
It's still slower than math.factorial, but it's faster than a math.gamma-based numba function, and the value is "exact".
So I'm looking for the fastest way to compute the factorial of a positive integer number (<= 20; to avoid overflow) inside a nopython numba function.
For values <= 20, Python uses a lookup table, as was suggested in the comments: https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452
import numpy as np
import numba as nb

# Factorials of 0..20; 20! is the largest that fits in a signed 64-bit integer.
LOOKUP_TABLE = np.array([
    1, 1, 2, 6, 24, 120, 720, 5040, 40320,
    362880, 3628800, 39916800, 479001600,
    6227020800, 87178291200, 1307674368000,
    20922789888000, 355687428096000, 6402373705728000,
    121645100408832000, 2432902008176640000], dtype='int64')

@nb.jit
def fast_factorial(n):
    if n > 20:
        raise ValueError
    return LOOKUP_TABLE[n]
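As a quick sanity check on the cutoff at 20, here is a minimal sketch that rebuilds the table values with math.factorial instead of typing them by hand and compares them against the signed 64-bit limit. The names INT64_MAX and table are made up for this illustration and are not part of numba or NumPy.

```python
import math

# Largest value a signed 64-bit integer can hold (illustrative constant).
INT64_MAX = 2**63 - 1

# Factorials of 0..20, generated rather than hard-coded.
table = [math.factorial(k) for k in range(21)]

print(table[20] <= INT64_MAX)           # 20! still fits in int64
print(math.factorial(21) <= INT64_MAX)  # 21! does not
```

A list built this way could be passed to np.array(..., dtype='int64') in place of the literal values.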
Called from Python, it's roughly twice as slow as math.factorial because of the numba dispatch overhead:
In [58]: %timeit math.factorial(10)
10000000 loops, best of 3: 79.4 ns per loop
In [59]: %timeit fast_factorial(10)
10000000 loops, best of 3: 173 ns per loop
But when it's called inside another numba function, it can be much faster:
def loop_python():
    for i in range(10000):
        for n in range(21):
            math.factorial(n)

@nb.njit
def loop_numba():
    for i in range(10000):
        for n in range(21):
            fast_factorial(n)
In [65]: %timeit loop_python()
10 loops, best of 3: 36.7 ms per loop
In [66]: %timeit loop_numba()
10000000 loops, best of 3: 73.6 ns per loop
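One caveat about the last number: 73.6 ns for 210,000 calls is well under a picosecond per call, which suggests (my assumption, I haven't inspected the generated code) that the compiler dropped the calls because their results are never used. Below is a rough sketch of a variant that accumulates the results so they can't simply be discarded. loop_numba_sum is a name invented here, the % 1000 is only there to keep the sum from overflowing int64, and the ImportError fallback is my own addition so the logic can still be checked without numba installed.

```python
import math

try:
    import numba as nb
    import numpy as np
    njit = nb.njit
    LOOKUP_TABLE = np.array([math.factorial(k) for k in range(21)], dtype=np.int64)
except ImportError:
    # Assumption: without numba/numpy, fall back to plain Python (no speedup).
    def njit(func):
        return func
    LOOKUP_TABLE = [math.factorial(k) for k in range(21)]

@njit
def fast_factorial(n):
    if n > 20:
        raise ValueError
    return LOOKUP_TABLE[n]

@njit
def loop_numba_sum():
    total = 0
    for i in range(10000):
        for n in range(21):
            # Use the result so the call is not dead code; % 1000 avoids overflow.
            total += fast_factorial(n) % 1000
    return total
```

Timing loop_numba_sum() with %timeit should give a more realistic comparison against loop_python(), although the optimizer may still simplify a loop this regular.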