 

How to make discrete Fourier transform (FFT) in numba.njit?


Hello fellow programmers

I am trying to compute a discrete Fourier transform with the numba.njit decorator in this minimal working example:

import numba
import numpy as np
import scipy
import scipy.fftpack

@numba.njit
def main():
    wave = [[[0.09254795,  0.10001078,  0.10744892, 0.07755555,  0.08506225, 0.09254795],
          [0.09907245,  0.10706145,  0.11502401,  0.08302302,  0.09105898, 0.09907245],
          [0.09565098,  0.10336405,  0.11105158,  0.08015589,  0.08791429, 0.09565098],
          [0.00181467,  0.001961,    0.00210684,  0.0015207,   0.00166789, 0.00181467]],
         [[-0.45816267, -0.46058367, -0.46289091, -0.45298182, -0.45562851, -0.45816267],
          [-0.49046506, -0.49305676, -0.49552669, -0.48491893, -0.48775223, -0.49046506],
          [-0.47352483, -0.47602701, -0.47841162, -0.46817027, -0.4709057, -0.47352483],
          [-0.00898358, -0.00903105, -0.00907629, -0.008882, -0.00893389, -0.00898358]],
         [[0.36561472,  0.36057289,  0.355442,  0.37542627,  0.37056626, 0.36561472],
          [0.39139261,  0.38599531,  0.38050268,  0.40189591,  0.39669325, 0.39139261],
          [0.37787385,  0.37266296,  0.36736003,  0.38801438,  0.38299141, 0.37787385],
          [0.00716892,  0.00707006,  0.00696945,  0.0073613,  0.00726601, 0.00716892]]]

    new_fft = scipy.fftpack.fft(wave)


if __name__ == '__main__':
    main()

Output:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'scipy.fftpack' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fftpack\\__init__.py'>)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fftpack.fft(wave)
    ^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fftpack.fft(wave)
    ^


Process finished with exit code 1

Unfortunately, scipy.fftpack.fft seems to be a legacy function that numba does not support. So I searched for alternatives and found two:

1. scipy.fft(wave), which is the updated version of the above-mentioned legacy function. It produces this error output:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) with parameters (list(list(list(float64))))
No type info available for Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) as a callable.
[1] During: resolving callee type: Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)


File "test2.py", line 21:
def main():
    <source elided>

    new_fft = scipy.fft(wave)
    ^


Process finished with exit code 1

2. np.fft.fft(wave), which seems to be supported but also produces an error:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'numpy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\numpy\\fft\\__init__.py'>)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = np.fft.fft(wave)
    ^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
    <source elided>

    new_fft = np.fft.fft(wave)
    ^


Process finished with exit code 1

Do you know an FFT function that works with the numba.njit decorator?

Artur Müller Romanov asked Jun 05 '20 10:06



2 Answers

If a 1D DFT is enough for you, you might as well use an FFT. Here is a Numba-friendly implementation, fft_1d(), that works on arbitrary input sizes:

import cmath
import numpy as np
import numba as nb


@nb.jit
def ilog2(n):
    result = -1
    if n < 0:
        n = -n
    while n > 0:
        n >>= 1
        result += 1
    return result


@nb.njit(fastmath=True)
def reverse_bits(val, width):
    result = 0
    for _ in range(width):
        result = (result << 1) | (val & 1)
        val >>= 1
    return result


@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    levels = ilog2(n)
    e_arr = np.empty_like(arr)
    coeff = (-2j if direct else 2j) * cmath.pi / n
    for i in range(n):
        e_arr[i] = cmath.exp(coeff * i)
    result = np.empty_like(arr)
    for i in range(n):
        result[i] = arr[reverse_bits(i, levels)]
    # Radix-2 decimation-in-time FFT
    size = 2
    while size <= n:
        half_size = size // 2
        step = n // size
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + half_size):
                temp = result[j + half_size] * e_arr[k]
                result[j + half_size] = result[j] - temp
                result[j] += temp
                k += step
        size *= 2
    return result


@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary inputs using chirp z-transform"""
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    m = 1 << (ilog2(n) + 2)
    e_arr = np.empty(n, dtype=np.complex128)
    for i in range(n):
        e_arr[i] = cmath.exp(-1j * cmath.pi * (i * i) / n)
    result = np.zeros(m, dtype=np.complex128)
    result[:n] = arr * e_arr
    coeff = np.zeros_like(result)
    coeff[:n] = e_arr.conjugate()
    coeff[-n + 1:] = e_arr[:0:-1].conjugate()
    return fft_convolve(result, coeff, fft_1d_r2)[:n] * e_arr / m


@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    return fft_1d_r2(fft_1d_r2(a_arr) * fft_1d_r2(b_arr), False)


@nb.njit(fastmath=True)
def fft_1d(arr):
    n = len(arr)
    if not n & (n - 1):
        return fft_1d_radix2_rbi(arr)
    else:
        return fft_1d_arb(arr)

Compared with the naïve DFT algorithm (dft_1d(), which is fundamentally the same as this), you gain orders of magnitude, although this is still typically a lot slower than np.fft.fft().

[Timing plot: fft_1d() vs. naïve DFT]

The relative speed varies greatly depending on the input size. For power-of-2 inputs, this is typically within one order of magnitude of np.fft.fft().

[Timing plot: power-of-2 input sizes]

For non-power-of-2 inputs, this is typically within two orders of magnitude of np.fft.fft().

[Timing plot: non-power-of-2 input sizes]

In the worst case (prime sizes or similar; here, sizes of power-of-2 + 1), this is a times as fast as np.fft.fft().

[Timing plot: worst-case (power-of-2 + 1) input sizes]

The non-linear behavior of the FFT timings is the result of needing a more complex algorithm for input sizes that are not powers of 2. This affects both this implementation and np.fft.fft(), but np.fft.fft() contains many more optimizations, which make it perform much better on average.
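
For reference, the identity behind fft_1d_arb() (Bluestein's algorithm, a form of the chirp z-transform) can be written as follows, using N for the input length and j, k for indices:

$$X_k=\sum_{j=0}^{N-1}x_j\,e^{-2\pi i\,jk/N},\qquad jk=\tfrac{1}{2}\left(j^2+k^2-(k-j)^2\right)$$

$$w_j=e^{-\pi i\,j^2/N}\quad\Longrightarrow\quad X_k=w_k\sum_{j=0}^{N-1}\left(x_j\,w_j\right)\overline{w_{k-j}}$$

Here w corresponds to e_arr in the code. The sum is a convolution of x_j w_j with the conjugated chirp, which fft_convolve() evaluates with power-of-2 FFTs of length m ≥ 2N − 1, so the zero padding prevents wrap-around. Since w_{-j} = w_j, the negative-index terms are stored at the end of coeff. The final division by m normalizes the inverse transform, which fft_1d_radix2_rbi(..., False) leaves unscaled.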

Alternative implementations of the power-of-2 FFT are shown here.

norok2 answered Oct 12 '22 22:10



The Numba documentation mentions that np.fft.fft is not supported. One solution is to use the objmode context manager to call Python functions that Numba does not support yet. Only the code inside the objmode block runs in object mode, so only that part can be slow. In your particular case, it will not be that slow, because np.fft.fft is already very fast, as pointed out by @tstanisl in the first comment on the question. Here is an example:

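from numba import njit, objmode
import numpy as np

@njit()
def compute_fft(x):
    # Only this block runs in object mode; the output type must be declared
    # so Numba can convert the result back to a native array
    with objmode(y='complex128[:]'):
        y = np.fft.fft(x)
    return y

@njit()
def main():
    # ... other nopython-mode code ...
    x = np.random.rand(100)  # 100 random samples (1D float64 array)
    fft_x = compute_fft(x)
    # ... other nopython-mode code ...
    return fft_x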

M . Franklin answered Oct 12 '22 23:10
