
Fastest way to find all pairs of close numbers in a Numpy array

Say I have a Numpy array of N = 10 random float numbers:

import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)

out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963 
         5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]

I want to find all unique pairs of numbers that differ from each other by no more than a tolerance tol = 1 (i.e. absolute difference <= 1). Specifically, I want to get all unique pairs of indexes. The indexes within each close pair should be sorted, and all close pairs should be sorted by the first index. I managed to write the following working code:

def all_close_pairs(arr, tol=1.):
    res = set()
    for i, x1 in enumerate(arr):
        for j, x2 in enumerate(arr):
            if i == j:
                continue
            if np.isclose(x1, x2, rtol=0., atol=tol):
                res.add(tuple(sorted([i, j])))
    res = np.array(list(res))
    return res[res[:,0].argsort()]

print(all_close_pairs(arr, tol=1.))

out[2]: [[1 5]
         [2 4]
         [3 7]
         [3 9]
         [7 9]]

However, in reality I have an array of N = 1000 numbers, and my code becomes extremely slow due to the nested for loops. I believe there are much more efficient ways to do this with Numpy vectorization. Does anyone know the fastest way to do this in Numpy?

Shaun Han asked Dec 07 '22 09:12


1 Answer

One efficient solution is to first sort the input values using index = np.argsort(arr). Then, you can build the sorted array with arr[index] and iterate over the close values on this fast contiguous array: since the array is sorted, the values close to a given item are its immediate successors, so the scan can stop at the first value that is too far away. This takes quasi-linear time when the number of pairs is small. If the number of pairs is large, the complexity becomes quadratic, simply because a quadratic number of pairs is generated. The resulting complexity is O(n log n + m), where n is the size of the input array and m is the number of pairs produced.
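If Numba is not available, the same sort-then-scan idea can be sketched in pure NumPy. This is a swapped-in variant, not the answer's implementation: np.searchsorted replaces the inner loop by finding, for each sorted position k, one past the last value within tol of s[k], and np.repeat then expands each contiguous run into pairs. The function name close_pairs_sorted_numpy is hypothetical; the sketch assumes tol >= 0 and no NaN values, and it tests s[j] <= s[k] + tol, which can differ from s[j] - s[k] <= tol by rounding exactly at the boundary.

```python
import numpy as np

def close_pairs_sorted_numpy(arr, tol=1.0):
    """Hypothetical helper: all index pairs (i, j), i < j, with |arr[i] - arr[j]| <= tol.

    Sort once; the close partners after sorted position k then form the
    contiguous run s[k+1:end[k]], where end[k] is found by binary search.
    """
    index = np.argsort(arr, kind="stable")
    s = arr[index]
    n = s.size
    # end[k] = first sorted position whose value exceeds s[k] + tol (inclusive bound)
    end = np.searchsorted(s, s + tol, side="right")
    counts = end - np.arange(n) - 1          # number of partners after each position
    first = np.repeat(np.arange(n), counts)  # sorted position of the first element of each pair
    # Offset of each pair inside its run: 1, 2, ..., counts[k] for every k
    starts = np.cumsum(counts) - counts
    offsets = np.arange(counts.sum()) - np.repeat(starts, counts) + 1
    second = first + offsets
    # Map sorted positions back to original indices, then order each pair and all pairs
    pairs = np.sort(index[np.stack([first, second], axis=1)], axis=1)
    order = np.lexsort((pairs[:, 1], pairs[:, 0]))  # primary key: first index
    return pairs[order]
```

With the question's arr and tol=1.0, this should reproduce the five pairs of out[2]. Its memory grows with the number of pairs m, so it has the same O(n log n + m) behaviour as the scan, but every step runs inside NumPy.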

To find values close to each other, one efficient way is to iterate over the values using Numba. While it might be possible in Numpy, it will likely not be efficient because the number of values to compare varies from one item to the next. Here is an implementation:

import numba as nb
import numpy as np

# Eagerly compiled: takes a sorted, contiguous float64 array and returns an (m, 2) int32 array
@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
    res = []
    for i in range(arr.size):
        val = arr[i]
        # Iterate over the close numbers (only once): j > i avoids duplicate pairs
        for j in range(i+1, arr.size):
            # The array is sorted, so arr[j] >= val and the gap only grows with j:
            # stop at the first value farther than tol (inclusive bound, as in the question)
            if arr[j] - val > tol:
                break
            res.append((i, j))
    if len(res) == 0: # No pairs: we need to help Numpy to know the shape
        return np.empty((0, 2), dtype=np.int32)
    return np.array(res, dtype=np.int32)

Finally, the indices need to be updated to refer to positions in the unsorted array rather than in the sorted one. You can do that with index[result].

Here is the resulting code:

index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])

Here is the result (the order differs from the question's, both inside each pair and across pairs, but you can sort it if needed):

array([[9, 3],
       [9, 7],
       [3, 7],
       [1, 5],
       [4, 2]])
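To get the question's ordering, a minimal sketch (the helper name canonical_order is hypothetical) sorts the two indices inside each pair, then sorts the rows by first and then second index with np.lexsort, whose last key is the primary one:

```python
import numpy as np

def canonical_order(pairs):
    """Sort each pair as (smaller, larger), then sort rows by first, then second index."""
    pairs = np.sort(np.asarray(pairs), axis=1)
    order = np.lexsort((pairs[:, 1], pairs[:, 0]))  # last key = primary key
    return pairs[order]

# Pairs as printed by the Numba version
print(canonical_order([[9, 3], [9, 7], [3, 7], [1, 5], [4, 2]]))
# expected rows: [1 5], [2 4], [3 7], [3 9], [7 9]
```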

Improving the complexity of the algorithm

If you need a faster algorithm, you can use another output format: for each input value, provide the range (first and last position in the sorted array) of the values close to it. To find this range, you can use a binary search (see np.searchsorted) on the sorted array. The resulting algorithm runs in O(n log n), regardless of the number of pairs. However, this format cannot directly give indices in the unsorted array, since the corresponding indices are not contiguous there.
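A minimal sketch of this range format, assuming tol >= 0 and no NaN values (the helper name close_ranges is hypothetical): two vectorized np.searchsorted calls give, for every sorted position k, a half-open range [lo[k], hi[k]) of positions in the sorted array. The range includes the element itself; mapping it through index gives the original indices, which are generally scattered.

```python
import numpy as np

def close_ranges(arr, tol=1.0):
    """Hypothetical helper for the range output format, O(n log n) overall.

    Returns (index, lo, hi): for the element at sorted position k, the values
    within tol of it (itself included) are arr[index[lo[k]:hi[k]]].
    """
    index = np.argsort(arr)
    s = arr[index]
    # Two vectorized binary searches on the sorted array, one per range bound
    lo = np.searchsorted(s, s - tol, side="left")   # first position with s >= s[k] - tol
    hi = np.searchsorted(s, s + tol, side="right")  # one past the last position with s <= s[k] + tol
    return index, lo, hi

# Demo on the question's array
arr = np.array([6.72278559, 4.88078399, 8.25495174, 0.31446388, 8.08049963,
                5.6561742, 2.97622499, 0.46695721, 9.90627399, 0.06825733])
index, lo, hi = close_ranges(arr, 1.0)
k = np.flatnonzero(index == 3)[0]  # sorted position of arr[3]
print(index[lo[k]:hi[k]])          # prints [9 3 7]
```

The output stays O(n) in size (two integers per element), which is what keeps the cost at O(n log n); expanding the ranges into explicit pairs brings back the O(m) term.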

Benchmark

Here are performance results on a random input with 1000 items and a tolerance of 1.0, on my machine:

Reference implementation:   ~17000    ms   (x 1)
Angelicos' implementation:    1773    ms   (x ~10)
Rivers' implementation:        122    ms   (x 139)
Rchome's implementation:        20    ms   (x 850)
Chris' implementation:           4.57 ms   (x 3720)
This implementation:             0.67 ms   (x 25373)
Jérôme Richard answered Jan 26 '23 01:01