Numpy: find row-wise common element efficiently

Suppose we are given two 2D numpy arrays a and b with the same number of rows. Assume furthermore that we know that each row i of a and b has at most one element in common, though this element may occur multiple times. How can we find this element as efficiently as possible?

An example:

import numpy as np

a = np.array([[1, 2, 3],
              [2, 5, 2],
              [5, 4, 4],
              [2, 1, 3]])

b = np.array([[4, 5],
              [3, 2],
              [1, 5],
              [0, 5]])

desiredResult = np.array([[np.nan],
                          [2],
                          [5],
                          [np.nan]])

It is easy to come up with a straightforward implementation by applying intersect1d along the first axis:

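from itertools import starmap

# One (possibly empty) array of common values per row. The arrays differ in
# length, so the result is kept as a list instead of being passed to np.array.
result = list(starmap(np.intersect1d, zip(a, b)))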

Apparently, using Python's built-in set operations is even quicker. Converting the result to the desired form is easy.
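For illustration, the conversion could look like the following sketch. The helper name `to_desired_form` is hypothetical; it assumes each per-row result (an array from `np.intersect1d` or a Python `set`) holds zero or one values, as the problem statement guarantees.

```python
import numpy as np
from itertools import starmap

def to_desired_form(per_row):
    """Turn per-row results (0- or 1-element arrays or sets) into an (n, 1) float array.

    Rows without a common element become NaN.
    """
    out = np.full((len(per_row), 1), np.nan)
    for i, common in enumerate(per_row):
        if len(common):
            # at most one common value per row, so the first one is the answer
            out[i, 0] = next(iter(common))
    return out

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])

from_intersect1d = to_desired_form(list(starmap(np.intersect1d, zip(a, b))))
from_sets = to_desired_form([set(ra).intersection(rb) for ra, rb in zip(a, b)])
```

Both variants produce the `desiredResult` array from the example above; the Python-level loop over rows remains, which is exactly the cost the question wants to avoid.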

However, I need an implementation that is as efficient as possible. Hence, I do not like the starmap, as I suppose that it requires a Python call for every row. I would like a purely vectorized option, and would be happy if it even exploited our additional knowledge that there is at most one common value per row.

Does anyone have ideas on how I could speed up the task and implement the solution more elegantly? I would be okay with using C code or Cython, but the coding effort should not be too high.

Samufi asked Jul 05 '19 00:07


2 Answers

Approach #1

Here's a vectorized one based on searchsorted2d (a row-wise variant of np.searchsorted; NumPy itself does not provide it) -

# Sort each row of a and b in-place
a.sort(1)
b.sort(1)

# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)

# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0

# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b

# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)

# Finally use np.where to choose between valid match 
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)
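The snippet relies on `searchsorted2d`, which is not defined above. Below is a self-contained sketch; the helper body is one possible implementation (an assumption, not necessarily the one the answer had in mind) that shifts every row into its own disjoint value range so that a single `np.searchsorted` over the flattened arrays acts row-wise. The wrapper name `divakars_method1` is hypothetical, and it sorts copies so the caller's arrays are left untouched.

```python
import numpy as np

def searchsorted2d(a, b):
    """Row-wise np.searchsorted: for each row i, positions of b[i] in sorted a[i].

    Assumes every row of `a` is already sorted. Row i is shifted by i * span,
    where span exceeds the combined value range of a and b, so the rows occupy
    disjoint, increasing intervals of the flattened array.
    (For floats, very large offsets can cost precision.)
    """
    m, n = a.shape
    lo = min(a.min(), b.min())
    hi = max(a.max(), b.max())
    span = hi - lo + 1
    offsets = span * np.arange(m)[:, None]
    pos = np.searchsorted((a + offsets).ravel(), (b + offsets).ravel())
    # convert global positions back to per-row positions
    return pos.reshape(m, -1) - n * np.arange(m)[:, None]

def divakars_method1(a, b):
    a = np.sort(a, axis=1)          # sorted copies; the inputs stay unchanged
    b = np.sort(b, axis=1)
    idx = searchsorted2d(a, b)
    # position == row length means b[i, k] exceeds every entry of a[i],
    # so index 0 cannot produce a false match
    idx[idx == a.shape[1]] = 0
    mask = np.take_along_axis(a, idx, axis=1) == b
    match_val = np.take_along_axis(b, mask.argmax(1)[:, None], axis=1)
    return np.where(mask.any(1)[:, None], match_val, np.nan)

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])
out = divakars_method1(a, b)
```

On the question's example this returns the (4, 1) array `[[nan], [2], [5], [nan]]`.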

Approach #2

Numba-based one for memory efficiency -

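import numpy as np
from numba import njit, prange

@njit(parallel=True)
def numba_f1(a,b,out):
    n,a_ncols = a.shape
    b_ncols = b.shape[1]
    # Rows are independent, so prange lets Numba split the outer loop
    # across threads (a plain range would stay serial under parallel=True)
    for i in prange(n):
        for j in range(a_ncols):
            m = False
            for k in range(b_ncols):
                m = a[i,j]==b[i,k]
                if m:
                    break
            # at most one common value per row, so stop at the first match
            if m:
                out[i] = a[i,j]
                break
    return out

def find_first_common_elem_per_row(a,b):
    # NaN marks rows without a common element
    out = np.full(len(a),np.nan)
    numba_f1(a,b,out)
    return out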

Approach #3

Here's another vectorized one based on stacking and sorting -

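r = np.arange(len(a))
ab = np.hstack((a,b))
# stable sort: among equal values, entries from a stay before entries from b
idx = ab.argsort(1, kind='stable')
ab_s = ab[r[:,None],idx]
# equal neighbours in each sorted row ...
m = ab_s[:,:-1] == ab_s[:,1:]
# ... whose left element comes from a and whose right element comes from b
m2 = (idx[:,:-1] < a.shape[1]) & (idx[:,1:] >= a.shape[1])
m3 = m & m2
out = np.where(m3.any(1),ab_s[r,m3.argmax(1)],np.nan)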
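To see why comparing neighbours after sorting finds the common value, here is a single-row trace of the stack-and-sort idea on row 1 of the question's example. This sketch assumes a stable sort, so equal values keep a-entries before b-entries, and it counts a pair only when its left element comes from a and its right element from b; that way, values repeated within b alone (or within a alone) are not reported.

```python
import numpy as np

a_row = np.array([2, 5, 2])
b_row = np.array([3, 2])
n_a = len(a_row)

ab = np.concatenate((a_row, b_row))   # positions < n_a come from a, the rest from b
idx = ab.argsort(kind="stable")       # ties keep their original (a-before-b) order
ab_s = ab[idx]                        # sorted values: [2, 2, 2, 3, 5]

same = ab_s[:-1] == ab_s[1:]                        # equal neighbours
a_then_b = (idx[:-1] < n_a) & (idx[1:] >= n_a)      # left from a, right from b
hits = same & a_then_b

common = ab_s[hits.argmax()] if hits.any() else np.nan
```

Only the pair formed by the last 2 from a and the 2 from b survives both masks, so `common` is 2.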

Approach #4

For an elegant but resource-hungry one (the boolean temporary has shape rows x b-columns x a-columns), we can make use of broadcasting -

m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
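Since that broadcast temporary grows with rows x b-columns x a-columns, one way to cap its memory is to apply the same expression to blocks of rows. A sketch follows; the function name and the `chunk_size` parameter are hypothetical choices, not part of the answer.

```python
import numpy as np

def common_elem_broadcast_chunked(a, b, chunk_size=1024):
    """Broadcasting approach applied to at most `chunk_size` rows at a time."""
    n = len(a)
    out = np.full(n, np.nan)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        a_c, b_c = a[start:stop], b[start:stop]
        # (rows, b_cols, a_cols) booleans, reduced to: does b_c[i, k] occur in a_c[i]?
        in_a = (a_c[:, None, :] == b_c[:, :, None]).any(2)
        vals = b_c[np.arange(stop - start), in_a.argmax(1)]
        out[start:stop] = np.where(in_a.any(1), vals, np.nan)
    return out

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])
out_chunked = common_elem_broadcast_chunked(a, b, chunk_size=2)
```

Peak temporary memory then scales with `chunk_size` instead of the total number of rows, at the price of one Python-level iteration per chunk.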
Divakar answered Sep 29 '22 08:09


Doing some research, I found that checking whether two lists are disjoint runs in O(n+m) expected time, where n and m are the lengths of the lists (see here). The idea is that insertion and lookup of elements run in expected constant time for hash maps. Therefore, inserting all elements from the first list into a hash map takes O(n) operations, and checking for each element in the second list whether it is already in the hash map takes O(m) operations. Solutions based on sorting, which run in O(n log(n) + m log(m)), are therefore not asymptotically optimal.
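A pure-Python sketch of this hash-based idea for a single pair of rows, using Python's built-in `set` as the hash table (the function name is hypothetical). It also stops at the first hit, which is where the at-most-one-common-value assumption pays off; the per-row cost is O(n + m) expected, but the loop itself still runs in the interpreter.

```python
import math

def first_common_hashed(row_a, row_b):
    """Return the value shared by row_a and row_b, or NaN if there is none."""
    seen = set(row_a)           # n insertions, expected O(1) each
    for value in row_b:         # up to m lookups, expected O(1) each
        if value in seen:
            return value        # at most one common value, so stop here
    return math.nan
```

For example, `first_common_hashed([2, 5, 2], [3, 2])` returns 2, while `first_common_hashed([1, 2, 3], [4, 5])` returns NaN.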

Though the solutions by @Divakar are highly efficient in many use cases, they are less efficient if the second dimension is large. In that case, a solution based on hash maps is better suited. I have implemented it in Cython as follows:

# distutils: language = c++
import numpy as np
cimport numpy as np
import cython
from libc.math cimport NAN
from libcpp.unordered_set cimport unordered_set
np.import_array()

@cython.boundscheck(False)
@cython.wraparound(False)
def get_common_element2d(np.ndarray[double, ndim=2] arr1, 
                         np.ndarray[double, ndim=2] arr2):

    cdef np.ndarray[double, ndim=1] result = np.empty(arr1.shape[0])
    cdef Py_ssize_t dim1 = arr1.shape[1]
    cdef Py_ssize_t dim2 = arr2.shape[1]
    cdef Py_ssize_t i, j
    # hash set holding the current row of arr1; cleared and reused for every row
    cdef unordered_set[double] tmpset

    for i in range(arr1.shape[0]):
        for j in range(dim1):
            # expected O(1) insertion
            tmpset.insert(arr1[i, j])
        for j in range(dim2):
            # expected O(1) lookup: is arr2[i, j] also in row i of arr1?
            if tmpset.count(arr2[i, j]):
                result[i] = arr2[i, j]
                break
        else:
            # no common element in this row
            result[i] = NAN
        tmpset.clear()

    return result

I have created test cases as follows:

import numpy as np
import timeit
from itertools import starmap
from mycythonmodule import get_common_element2d

m, n = 3000, 3000
a = np.random.rand(m, n)
b = np.random.rand(m, n)

for i, row in enumerate(a):
    if np.random.randint(2):
        common = np.random.choice(row, 1)
        b[i][np.random.choice(np.arange(n), np.random.randint(min(n,20)), False)] = common

# we need to copy the arrays on each test run, otherwise they 
# will remain sorted, which would bias the results

%timeit [set(aa).intersection(bb) for aa, bb in zip(a.copy(), b.copy())]
# returns 3.11 s ± 56.8 ms

%timeit list(starmap(np.intersect1d, zip(a.copy(), b.copy())))
# returns 1.83 s ± 55.4

# test sorting method
# divakarsMethod1 is Approach #1 in @Divakar's answer, wrapped in a function
%timeit divakarsMethod1(a.copy(), b.copy())
# returns 1.88 s ± 18 ms

# test hash map method
%timeit get_common_element2d(a.copy(), b.copy())
# returns 1.46 s ± 22.6 ms
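`%timeit` is an IPython/Jupyter magic. In a plain Python script, the `timeit` module imported above can take the same kind of measurement; the sketch below (the helper `best_time` and the stand-in `sort_in_place` are hypothetical) copies the inputs inside the timed call, mirroring the `%timeit` lines, so that in-place sorting cannot leak into later runs. As in those lines, the copies are part of the measured time.

```python
import timeit
import numpy as np

def best_time(func, a, b, repeat=3):
    """Best wall-clock time in seconds of func(a.copy(), b.copy()) over `repeat` runs."""
    times = timeit.repeat(lambda: func(a.copy(), b.copy()), number=1, repeat=repeat)
    return min(times)

def sort_in_place(x, y):
    # stand-in for a method that mutates its inputs, as Approach #1 does
    x.sort(1)
    y.sort(1)

a = np.random.rand(200, 50)
b = np.random.rand(200, 50)
a_before = a.copy()
t = best_time(sort_in_place, a, b)
```

Taking the minimum over repeats is a common way to reduce noise from other processes.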

These results indicate that, for this array shape, the naive starmap/intersect1d approach is actually faster than some of the vectorized versions. However, the vectorized algorithms play to their strengths when many rows with few columns are considered (a different use case). In such cases, the vectorized approaches are more than 5 times faster than the naive approach, and the sorting method turns out to be the best.

Conclusion: I will go with the hash-map-based Cython version, because it is among the most efficient variants in both use cases. If I had to set up Cython first, I would use the sorting-based method instead.

Samufi answered Sep 29 '22 06:09