Hello fellow programmers
I am trying to compute a discrete Fourier transform in this minimal working example with the numba.njit decorator:
import numba
import numpy as np
import scipy
import scipy.fftpack

@numba.njit
def main():
    wave = [[[0.09254795, 0.10001078, 0.10744892, 0.07755555, 0.08506225, 0.09254795],
             [0.09907245, 0.10706145, 0.11502401, 0.08302302, 0.09105898, 0.09907245],
             [0.09565098, 0.10336405, 0.11105158, 0.08015589, 0.08791429, 0.09565098],
             [0.00181467, 0.001961, 0.00210684, 0.0015207, 0.00166789, 0.00181467]],
            [[-0.45816267, -0.46058367, -0.46289091, -0.45298182, -0.45562851, -0.45816267],
             [-0.49046506, -0.49305676, -0.49552669, -0.48491893, -0.48775223, -0.49046506],
             [-0.47352483, -0.47602701, -0.47841162, -0.46817027, -0.4709057, -0.47352483],
             [-0.00898358, -0.00903105, -0.00907629, -0.008882, -0.00893389, -0.00898358]],
            [[0.36561472, 0.36057289, 0.355442, 0.37542627, 0.37056626, 0.36561472],
             [0.39139261, 0.38599531, 0.38050268, 0.40189591, 0.39669325, 0.39139261],
             [0.37787385, 0.37266296, 0.36736003, 0.38801438, 0.38299141, 0.37787385],
             [0.00716892, 0.00707006, 0.00696945, 0.0073613, 0.00726601, 0.00716892]]]
    new_fft = scipy.fftpack.fft(wave)

if __name__ == '__main__':
    main()
Output:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'scipy.fftpack' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fftpack\\__init__.py'>)
File "test2.py", line 21:
def main():
<source elided>
new_fft = scipy.fftpack.fft(wave)
^
[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)
File "test2.py", line 21:
def main():
<source elided>
new_fft = scipy.fftpack.fft(wave)
^
Process finished with exit code 1
Unfortunately, scipy.fftpack.fft seems to be a legacy function that is not supported by numba, so I searched for alternatives. I found two:
1. scipy.fft(wave), which is the updated version of the legacy function mentioned above. It produces this error output:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) with parameters (list(list(list(float64))))
No type info available for Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) as a callable.
[1] During: resolving callee type: Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)
File "test2.py", line 21:
def main():
<source elided>
new_fft = scipy.fft(wave)
^
Process finished with exit code 1
2. np.fft.fft(wave), which seems to be supported but also produces an error:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'numpy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\numpy\\fft\\__init__.py'>)
File "test2.py", line 21:
def main():
<source elided>
new_fft = np.fft.fft(wave)
^
[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)
File "test2.py", line 21:
def main():
<source elided>
new_fft = np.fft.fft(wave)
^
Process finished with exit code 1
Do you know of an FFT function that works with the numba.njit decorator?
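For reference, the first error is not only a numba limitation: in SciPy 1.4 and later, scipy.fft is a module, so scipy.fft(wave) tries to call the module itself, which is what "Invalid use of Module ... as a callable" reports. The transform function is scipy.fft.fft. A quick check in plain Python (outside numba, assuming SciPy 1.4 or later is installed; the short wave array is just the first row of the question's data):

```python
import types

import numpy as np
import scipy.fft

# First row of the question's data, used only for illustration
wave = np.array([0.09254795, 0.10001078, 0.10744892, 0.07755555, 0.08506225, 0.09254795])

# scipy.fft is a module, not a function
print(isinstance(scipy.fft, types.ModuleType))  # True

# The transform itself lives inside the module and agrees with NumPy's
spectrum = scipy.fft.fft(wave)
print(np.allclose(spectrum, np.fft.fft(wave)))  # True
```

This does not make it usable inside @numba.njit, which still cannot compile calls into SciPy's FFT.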
If you are happy with a 1D DFT, you might as well use an FFT.
Here is a Numba-friendly implementation, fft_1d(), that works on arbitrary input sizes:
import cmath
import numpy as np
import numba as nb

@nb.njit
def ilog2(n):
    # Integer floor(log2(|n|)); returns -1 for n == 0
    result = -1
    if n < 0:
        n = -n
    while n > 0:
        n >>= 1
        result += 1
    return result

@nb.njit(fastmath=True)
def reverse_bits(val, width):
    # Reverse the lowest `width` bits of `val`
    result = 0
    for _ in range(width):
        result = (result << 1) | (val & 1)
        val >>= 1
    return result

@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    # Iterative radix-2 FFT for power-of-2 sizes;
    # direct=False computes the unnormalized inverse transform
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    levels = ilog2(n)
    # Precompute the twiddle factors
    e_arr = np.empty_like(arr)
    coeff = (-2j if direct else 2j) * cmath.pi / n
    for i in range(n):
        e_arr[i] = cmath.exp(coeff * i)
    # Reorder the input by bit-reversed index
    result = np.empty_like(arr)
    for i in range(n):
        result[i] = arr[reverse_bits(i, levels)]
    # Radix-2 decimation-in-time FFT
    size = 2
    while size <= n:
        half_size = size // 2
        step = n // size
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + half_size):
                temp = result[j + half_size] * e_arr[k]
                result[j + half_size] = result[j] - temp
                result[j] += temp
                k += step
        size *= 2
    return result

@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary inputs using chirp z-transform"""
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    # Padded power-of-2 length, always larger than 2 * n
    m = 1 << (ilog2(n) + 2)
    # Chirp factors exp(-i * pi * k^2 / n)
    e_arr = np.empty(n, dtype=np.complex128)
    for i in range(n):
        e_arr[i] = cmath.exp(-1j * cmath.pi * (i * i) / n)
    result = np.zeros(m, dtype=np.complex128)
    result[:n] = arr * e_arr
    # Conjugate chirp, wrapped around at the end for negative indices
    coeff = np.zeros_like(result)
    coeff[:n] = e_arr.conjugate()
    coeff[-n + 1:] = e_arr[:0:-1].conjugate()
    return fft_convolve(result, coeff, fft_1d_r2)[:n] * e_arr / m

@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    # Circular convolution via FFT (unnormalized: the caller divides by the length)
    return fft_1d_r2(fft_1d_r2(a_arr) * fft_1d_r2(b_arr), False)

@nb.njit(fastmath=True)
def fft_1d(arr):
    n = len(arr)
    # n & (n - 1) == 0 exactly when n is a power of 2
    if not n & (n - 1):
        return fft_1d_radix2_rbi(arr)
    else:
        return fft_1d_arb(arr)
Compared to the naïve DFT algorithm (dft_1d(), which is fundamentally the same as this), you are gaining orders of magnitude, while you are still typically a lot slower than np.fft.fft().
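The dft_1d() mentioned above is not included in this post, so the following is a hypothetical minimal sketch of a naïve O(n²) DFT (not the original function), written to be Numba-friendly. It also shows how to apply a 1D transform along the last axis of data shaped like the question's wave (3 × 4 × 6), which is what np.fft.fft() does by default. The try/except fallback to a no-op njit is an assumption so the snippet also runs without numba, and the random input only stands in for the question's values.

```python
import cmath

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python when numba is not installed
    def njit(func):
        return func


@njit
def dft_1d(arr):
    """Naive O(n^2) DFT with the same sign convention as np.fft.fft."""
    n = arr.shape[0]
    result = np.zeros(n, dtype=np.complex128)
    for k in range(n):
        for j in range(n):
            result[k] += arr[j] * cmath.exp(-2j * cmath.pi * k * j / n)
    return result


@njit
def dft_last_axis(arr2d):
    """Apply dft_1d to every row, i.e. along the last axis like np.fft.fft."""
    out = np.empty(arr2d.shape, dtype=np.complex128)
    for row in range(arr2d.shape[0]):
        out[row] = dft_1d(arr2d[row])
    return out


# Same shape as the question's data (3, 4, 6); random values for illustration
rng = np.random.default_rng(0)
wave = rng.standard_normal((3, 4, 6))

# Collapse the leading axes, transform each length-6 row, restore the shape
spectrum = dft_last_axis(wave.reshape(-1, wave.shape[-1])).reshape(wave.shape)
print(np.allclose(spectrum, np.fft.fft(wave)))  # True
```

Calling fft_1d() instead of dft_1d() inside dft_last_axis() should give the same result with the faster algorithm.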
The relative speed varies greatly depending on the input sizes.
For power-of-2 inputs, this is typically within one order of magnitude of np.fft.fft().
For non-power-of-2 inputs, this is typically within two orders of magnitude of np.fft.fft().
For the worst case (prime numbers or so; here, power-of-2 + 1), this is a times as fast as np.fft.fft().
The non-linear behavior of the FFT timings is the result of the need for a more complex algorithm for input sizes that are not powers of 2. This affects both this implementation and np.fft.fft(), but np.fft.fft() contains many more optimizations, which make it perform much better on average.
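The more complex algorithm used here for non-power-of-2 sizes is the chirp z-transform (Bluestein's algorithm) in fft_1d_arb(). As a worked derivation matching that code: w_t corresponds to e_arr, a to the zero-padded result, b to coeff, and IFFT denotes the unnormalized inverse transform, like fft_1d_r2(..., False):

```latex
\begin{aligned}
X_k &= \sum_{j=0}^{n-1} x_j\, e^{-2\pi i jk/n},
  \qquad jk = \tfrac{1}{2}\left(j^2 + k^2 - (k-j)^2\right) \\
X_k &= w_k \sum_{j=0}^{n-1} (x_j w_j)\,\overline{w_{k-j}},
  \qquad w_t = e^{-\pi i t^2/n} \\
X_k &= \frac{w_k}{m}\,\mathrm{IFFT}_m\!\big(\mathrm{FFT}_m(a)\cdot\mathrm{FFT}_m(b)\big)_k,
  \qquad a_j = x_j w_j,\quad b_t = \overline{w_t},\quad m \ge 2n - 1
\end{aligned}
```

Because the index k − j ranges over −(n−1), …, n−1, b_t for negative t is stored at position m − |t|; that is the coeff[-n + 1:] wrap-around in the code. The code picks m = 2^(⌊log₂ n⌋ + 2), which is always larger than 2n, so the wrapped entries never overlap the first n and the circular convolution equals the linear one. The cost is three power-of-2 FFTs of length m instead of one of length n.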
Alternate implementations of power-of-2 FFT are shown here.
The numba documentation mentions that np.fft.fft is not supported. A solution is to use the objmode context to call Python functions that are not supported yet. Only the code inside the objmode context runs in object mode, and therefore it can be slow. For your particular case, this part will not be that slow, because np.fft.fft is already very fast, as pointed out by @tstanisl in the first comment on the question. Here is an example:
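from numba import njit, objmode
import numpy as np

@njit()
def compute_fft(x):
    # Only this block runs in object mode; the annotation declares
    # the type of y for the surrounding nopython code
    with objmode(y='complex128[:]'):
        y = np.fft.fft(x)
    return y

@njit()
def main():
    # ... your other nopython code ...
    x = np.random.random(100)
    fft_x = compute_fft(x)
    # ... your other nopython code ...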