
Efficiently find the indices of shared values (with repeats) between two large arrays

Problem description

Let's take this simple array set

import numpy as np

            # 0,1,2,3,4
a = np.array([1,1,3,4,6])
b = np.array([6,6,1,3])

From these two arrays I want to get the indices of all possible matches. So for the value 1 we get the pairs 0,2 and 1,2, with the complete output looking like:

0,2 # 1
1,2 # 1
2,3 # 3
4,0 # 6
4,1 # 6

Note that the arrays are not sorted, nor do they contain only unique elements - two conditions often assumed in other answers (see bottom). The above example is very small; however, I have to apply this to arrays of ~40K elements.


Tried approaches

1. Python loop approach
indx = []
for i, aval in enumerate(a):
    for j, bval in enumerate(b):
        if aval == bval:
            indx.append([i,j])
# [[0, 2], [1, 2], [2, 3], [4, 0], [4, 1]]
2. Python dict approach
from collections import defaultdict

adict = defaultdict(list)
bdict = defaultdict(list)
for i, aval in enumerate(a): adict[aval].append(i)
for j, bval in enumerate(b): bdict[bval].append(j)

for val, a_positions in adict.items():
    for b_position in bdict[val]:
        for a_position in a_positions:
            print(a_position, b_position)
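
The dict approach above prints the pairs; collecting them into a list instead reproduces the loop output for these example arrays (a minimal sketch - in general the pairs come out grouped by value, in the insertion order of adict):

pairs = []
for val, a_positions in adict.items():
    for b_position in bdict[val]:
        for a_position in a_positions:
            pairs.append([a_position, b_position])
# [[0, 2], [1, 2], [2, 3], [4, 0], [4, 1]]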
3. NumPy where

print(np.where(a.reshape(-1,1) == b))
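
np.where returns a tuple of two index arrays (indices into a and indices into b); np.argwhere on the same broadcasted comparison gives the pairs directly as rows (a minimal sketch):

pairs = np.argwhere(a.reshape(-1, 1) == b)
# array([[0, 2],
#        [1, 2],
#        [2, 3],
#        [4, 0],
#        [4, 1]])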

4. Polars dataframes

Converting the arrays to DataFrames and then joining them with Polars:

import polars as pl
a = pl.DataFrame( {'x': a, 'apos': list(range(len(a)))} )
b = pl.DataFrame( {'x': b, 'bpos': list(range(len(b)))} )
a.join(b, how='inner', on='x')
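
The joined frame then carries the index pairs in its apos/bpos columns, so extracting a NumPy array of pairs is one more call (a minimal sketch, assuming the bpos column name above; note that the row order of a join is not guaranteed):

pairs = a.join(b, how='inner', on='x').select(['apos', 'bpos']).to_numpy()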

"Big data"
On "big" data using Polars seems the fastest now with around 0.02 secs. I'm suprised that creating DataFrames first and then joining them is faster than any other approach I could come up with so curious if there is any other way to beat it :)
a = np.random.randint(0,1000, 40000)
b = np.random.randint(0,1000, 40000)

Using the above data:

  1. python loop: 218s
  2. python dict: 0.03s
  3. numpy.where: 4.5s
  4. polars: 0.02s
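
For reference, the timings above can be reproduced with a simple timeit harness (a minimal sketch; dict_approach, where_approach and polars_approach are hypothetical wrappers around approaches 2-4 above):

import timeit

a = np.random.randint(0, 1000, 40000)
b = np.random.randint(0, 1000, 40000)

# dict_approach, where_approach and polars_approach are hypothetical
# wrapper functions around the snippets shown above
for name, fn in [('dict', dict_approach),
                 ('where', where_approach),
                 ('polars', polars_approach)]:
    print(name, timeit.timeit(lambda: fn(a, b), number=10) / 10, 's')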

How related questions didn't solve this

  • Return common element indices between two numpy arrays, only returns the indices of matches in one of the arrays, not both
  • Find indices of common values in two arrays, returns the matching indices of A with B and B with A, but not the paired indices (see example above)

I'm very surprised a DataFrame library is currently the fastest, so I'm curious to see if there are other approaches to beat this speed :) Anything goes: Cython, Numba, Pythran, etc.

asked Oct 20 '25 by CodeNoob

1 Answer

An alternative, completely different solution is to sort the arrays: retrieve the order of each array with np.argsort, gather the sorted values, and then walk in lockstep over the two sets of positions sorted by value. This last operation can (again) be implemented efficiently in Numba or Cython. It can actually be split into two parts: one finding the slices of a and b that match the same value (similar to a merge operation), and one doing the actual Cartesian product for each pair of matching slices. Splitting this into two steps enables the second one (which is expensive) to be computed in parallel (and in place if possible). The complexity of finding the matching offsets is O(n log n) with NumPy (one could reach the theoretically optimal O(n) time using a radix sort).

Here is the resulting implementation:

import numba as nb
import numpy as np

# Support both 32-bit and 64-bit integer values (the positions from np.argsort are always int64)
@nb.njit(['(int64[::1],int64[::1],int64[::1],int64[::1])', '(int64[::1],int64[::1],int32[::1],int32[::1])'])
def find_matching_offsets(a_positions, b_positions, a_sorted_values, b_sorted_values):
    n, m = a_positions.size, b_positions.size
    # One row per matched value: (a_start, a_end, b_start, b_end) slice bounds
    result = np.empty((n, 4), dtype=np.int32)

    a_pos, b_pos, cur = 0, 0, 0
    while a_pos < n and b_pos < m:
        a_val = a_sorted_values[a_pos]
        b_val = b_sorted_values[b_pos]

        if a_val < b_val:
            a_pos += 1
            continue

        if a_val > b_val:
            b_pos += 1
            continue

        # Find the end of the run of equal values in a
        a_end = n
        for i in range(a_pos + 1, n):
            if a_sorted_values[i] != a_val:
                a_end = i
                break

        # Find the end of the run of equal values in b
        b_end = m
        for i in range(b_pos + 1, m):
            if b_sorted_values[i] != b_val:
                b_end = i
                break

        result[cur, 0] = a_pos
        result[cur, 1] = a_end
        result[cur, 2] = b_pos
        result[cur, 3] = b_end
        cur += 1

        a_pos = a_end
        b_pos = b_end

    return result[:cur]

@nb.njit(['(int64[::1],int64[::1],int32[:,::1])'], parallel=True)
def do_cartesian_product(a_positions, b_positions, offsets):
    size = 0
    cur = 0
    result_offsets = np.empty(offsets.shape[0], dtype=np.int32)

    # Compute the size of the output
    for i in range(offsets.shape[0]):
        a_pos, a_end, b_pos, b_end = offsets[i]
        assert a_end > a_pos and b_end > b_pos
        result_offsets[cur] = size
        size += (a_end - a_pos) * (b_end - b_pos)
        cur += 1

    assert size > 0  # note: fails if there is no match at all
    result = np.empty((size, 2), dtype=np.int32)

    # Fill the output in parallel
    for i in nb.prange(offsets.shape[0]):
        a_pos, a_end, b_pos, b_end = offsets[i]
        offset = result_offsets[i]
        local_cur = 0
        for j in range(a_pos, a_end):
            for k in range(b_pos, b_end):
                local_offset = offset + local_cur
                result[local_offset, 0] = a_positions[j]
                result[local_offset, 1] = b_positions[k]
                local_cur += 1

    return result

def sorted_based_compute(a, b):
    # Sort both arrays indirectly, then match runs of equal values
    a_positions = np.argsort(a)
    b_positions = np.argsort(b)
    a_sorted_values = a[a_positions]
    b_sorted_values = b[b_positions]
    offsets = find_matching_offsets(a_positions, b_positions, a_sorted_values, b_sorted_values)
    return do_cartesian_product(a_positions, b_positions, offsets)
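
Usage is straightforward; on the example arrays from the question the result matches the Python loop output (in general the pairs come out grouped by value rather than in loop order):

a = np.array([1, 1, 3, 4, 6])
b = np.array([6, 6, 1, 3])
print(sorted_based_compute(a, b))
# [[0 2]
#  [1 2]
#  [2 3]
#  [4 0]
#  [4 1]]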

This solution is faster than the dict-based Numba version and certainly comes close to the limit of what is possible with NumPy/Numba (without making additional assumptions on the input). Here are the performance results (on my 6-core machine):

Python dict:                        24.7 ms
Polars:                             12.3 ms
Dict-based Numba version:           11.3 ms
Sort-based Numpy+Numba version:      5.0 ms  <----

Note that ~60% of the time is spent in the np.argsort calls and the rest is basically the Cartesian product. This could theoretically be improved using a parallel sort, but AFAIK this is not possible with NumPy yet (and pretty hard to do in Numba).

answered Oct 21 '25 by Jérôme Richard