
numpy fft is fast for lengths that are products of small primes, but how small?

I've seen several examples showing that if the input length is a product of 2, 3, 5, 7, etc., then numpy's fft implementation is fast. But what is the largest prime number that is still considered "small" here?

Tavin asked Jan 03 '23 10:01


1 Answer

Note that scipy's FFT has radices of 2, 3, 4, and 5 (reference). I assume numpy has a similar implementation, which would make 5 the largest efficient prime factor in FFT lengths.
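To check whether a given length fits that description, one can factor it and look at its prime factors. The following is a minimal sketch, not part of numpy or scipy: `prime_factors` and `is_smooth` are hypothetical helper names, and the default bound of 5 is simply the assumption stated above.

```python
def prime_factors(n):
    """Return the prime factors of n (with multiplicity) by trial division."""
    factors = []
    p = 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)  # whatever remains is itself prime
    return factors


def is_smooth(n, max_prime=5):
    """True if no prime factor of n exceeds max_prime (hypothetical helper)."""
    return all(p <= max_prime for p in prime_factors(n))


# Compare a few candidate FFT lengths
for n in (60, 62, 64):
    print(n, prime_factors(n), is_smooth(n))
```

Here 60 and 64 only contain the factors 2, 3 and 5, while 62 = 2 * 31 contains a large prime and would be expected to be a slower length under this assumption.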


Empirically, the largest prime I'd consider "small" for the purpose of FFT performance is 11. But any input length of less than about 30 is going to be pretty fast for practical purposes. Any algorithmic performance gains are certainly going to be dwarfed by Python's execution overhead. Things get more interesting for larger input lengths.

Here are some performance results for small FFTs (median execution time over 500 batches of 1000 FFTs each):

[Figure: median FFT execution time plotted against input length n]

I have marked prime-valued n in red and powers of two in green.

Note the following observations:

  • In general, the FFT is slow for primes but fast for powers of two. This is pretty much expected and serves as a sanity check on the results.

  • No performance difference was measurable for n <= 11. This may be due to the FFT implementation or to execution overhead.

  • Primes of 31 (maybe 29) and higher are clearly slower than other nearby values.

  • There are some non-power-of-two values that also give good performance. These are probably highly composite numbers.
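If padding the input is acceptable, these observations suggest rounding a slow length up to the next one built only from small primes. Below is a rough sketch under that assumption: `next_smooth_length` is a hypothetical helper, and the default prime set (up to 11) is just the empirical threshold mentioned above, not something taken from numpy's source.

```python
def next_smooth_length(n, primes=(2, 3, 5, 7, 11)):
    """Smallest m >= n whose prime factors all lie in `primes` (hypothetical helper)."""
    m = max(n, 1)
    while True:
        r = m
        # Divide out every allowed prime; if nothing is left, m is smooth
        for p in primes:
            while r % p == 0:
                r //= p
        if r == 1:
            return m
        m += 1


# Slow prime lengths from the measurements and their nearest "fast" neighbours
for n in (31, 37, 59, 61):
    print(n, '->', next_smooth_length(n))
```

The result could then be passed as the `n` argument of `np.fft.fft`, which zero-pads the input to that length. Keep in mind that zero-padding changes the frequency bin spacing, so this only helps when the exact transform length does not matter.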

The measurements were performed like this:

import numpy as np
import matplotlib.pyplot as plt
from time import perf_counter as time


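# Time 500 batches of 1000 FFTs for each length n = 2..64
N = np.arange(2, 65)
times = np.empty((500, N.size))
for i, n in enumerate(N):
    for r in range(times.shape[0]):
        x = np.random.randn(1000, n)  # one batch: 1000 signals of length n
        t = time()
        y = np.fft.fft(x, axis=-1)
        t = time() - t
        times[r, i] = t  # elapsed seconds for the whole batch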


# Median over batches; N starts at 2, so length n sits at index n - 2
med = np.median(times, axis=0)
plt.plot(N, med, 'k')

# Markers are offset by +/-0.0005 s so they do not sit on top of the curve
primes = np.array([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61])
plt.plot(primes, med[primes-2]+0.0005, 'rx', label='n = prime')

ptwos = np.array([2, 4, 8, 16, 32, 64])
plt.plot(ptwos, med[ptwos-2]-0.0005, 'gx', label='n = 2**k')

plt.legend(loc='best')
plt.xlabel('n')
plt.ylabel('time')
plt.grid()
plt.show()
MB-F answered Jan 05 '23 23:01