 

Intersection of sorted numpy arrays

Tags:

numpy

I have a list of sorted numpy arrays. What is the most efficient way to compute the sorted intersection of these arrays?

In my application, I expect the number of arrays to be less than 10^4, I expect the individual arrays to be of length less than 10^7, and I expect the length of the intersection to be close to p*N, where N is the length of the largest array and where 0.99 < p <= 1.0. The arrays are loaded from disk and can be loaded in batches if they won't all fit in memory at once.

A quick and dirty approach is to repeatedly invoke numpy.intersect1d(). That seems inefficient though as intersect1d() does not take advantage of the fact that the arrays are sorted.
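For comparison, the quick-and-dirty baseline is a one-line fold with functools.reduce. Note that np.intersect1d returns the sorted unique values common to both inputs and re-sorts its inputs internally on every call, so the existing order is not exploited. The helper name intersect_all_baseline and the sample arrays are illustrative only.

```python
from functools import reduce

import numpy as np


def intersect_all_baseline(arrays):
    """Sorted unique values common to every array, via repeated np.intersect1d.

    Each call re-sorts its two inputs internally, so the fact that the
    arrays are already sorted is not used.
    """
    return reduce(np.intersect1d, arrays)


arrays = [np.array([1, 2, 3, 5, 8]), np.array([1, 3, 5, 7]), np.array([0, 1, 5, 9])]
print(intersect_all_baseline(arrays))  # [1 5]
```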

Asked Oct 04 '17 19:10 by dshin


1 Answer

Since intersect1d sorts its inputs on every call, repeating it is indeed inefficient.

Instead, sweep the current intersection and each new sample together to build the next intersection. Because both are sorted, this merge-style pass runs in linear time and keeps the result in order.
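If numba is not available, the sortedness can still be exploited in plain NumPy by looking up each element with binary search (np.searchsorted) instead of a linear sweep. This is a different technique from the sweep: it costs O(len(a) log len(b)) and keeps every copy of a value in a that appears in b, so it matches the sweep only when a has no duplicates. The function name intersect_sorted_unique is hypothetical.

```python
import numpy as np


def intersect_sorted_unique(a, b):
    """Elements of sorted array `a` that also occur in sorted array `b`.

    Assumes both inputs are sorted ascending. Every copy of a value in `a`
    is kept if that value occurs in `b`, so this equals the two-pointer
    merge only when `a` has no duplicates.
    """
    if b.size == 0:
        return a[:0]
    # Position where each a[i] would be inserted into b to keep b sorted.
    idx = np.searchsorted(b, a)
    # Values larger than b[-1] get idx == len(b); clip to a valid slot.
    idx = np.minimum(idx, b.size - 1)
    # a[i] is in b exactly when b holds a[i] at its insertion point.
    return a[b[idx] == a]


a = np.array([1, 3, 4, 7, 9])
b = np.array([0, 1, 4, 5, 9, 10])
print(intersect_sorted_unique(a, b))  # [1 4 9]
```

The result is a subsequence of a, so it stays sorted without any extra work.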

Tasks like this often have to be hand-tuned with low-level routines.

Here is a way to do that with numba:

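from numba import njit
import numpy as np

@njit
def drop_missing(intersect, sample):
    # Two-pointer sweep over two sorted arrays; keeps the elements found in both.
    i = j = k = 0
    new_intersect = np.empty_like(intersect)  # result is at most len(intersect)
    while i < intersect.size and j < sample.size:
        if intersect[i] == sample[j]:  # the 99% case: element is kept
            new_intersect[k] = intersect[i]
            k += 1
            i += 1
            j += 1
        elif intersect[i] < sample[j]:  # missing from sample: drop it
            i += 1
        else:  # extra element in sample: skip it
            j += 1
    return new_intersect[:k]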
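A quick sanity check of drop_missing on a tiny input with duplicates: it keeps each value as many times as it occurs in both arrays (a multiset intersection), which the Counter cross-check confirms. The function is repeated so the snippet runs on its own; the try/except fallback that runs it as plain Python when numba is not installed is an addition, not part of the answer.

```python
from collections import Counter

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python without numba
    def njit(func):
        return func


@njit
def drop_missing(intersect, sample):
    # Same two-pointer sweep as above; both inputs must be sorted ascending.
    i = j = k = 0
    new_intersect = np.empty_like(intersect)
    while i < intersect.size and j < sample.size:
        if intersect[i] == sample[j]:
            new_intersect[k] = intersect[i]
            k += 1
            i += 1
            j += 1
        elif intersect[i] < sample[j]:
            i += 1
        else:
            j += 1
    return new_intersect[:k]


a = np.array([1, 2, 2, 2, 5, 7, 7])
b = np.array([2, 2, 3, 5, 7, 8])
result = drop_missing(a, b)
print(result)  # [2 2 5 7]

# Cross-check against a multiset (Counter) intersection of the same data.
expected = sorted((Counter(a.tolist()) & Counter(b.tolist())).elements())
assert result.tolist() == expected
```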

Now generate some test samples:

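n = 10**7
ref = np.random.randint(0, n, n)  # n random ints (duplicates are likely)
ref.sort()

def perturbation(sample, k):
    # Split at k-1 random positions and drop the last element of each piece,
    # removing up to k elements while keeping the array sorted.
    rands = np.random.randint(0, n, k - 1)
    rands.sort()
    pieces = np.split(sample, rands)
    return np.concatenate([a[:-1] for a in pieces])

samples = [perturbation(ref, 100) for _ in range(10)]  # 10 similar samples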

And a run on the 10 samples:

def find_intersect(samples):
    # Fold drop_missing over the samples, starting from the first one.
    intersect = samples[0]
    for sample in samples[1:]:
        intersect = drop_missing(intersect, sample)
    return intersect

In [18]: %time u=find_intersect(samples)
Wall time: 307 ms

In [19]: len(u)
Out[19]: 9999009     

That is about 34 ms per merge (307 ms for 9 merges), so 10^4 arrays of this size should take on the order of 5 minutes, not counting loading time.
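Since the arrays come from disk, the same fold can consume them one file at a time, so only the running intersection and the current array are needed at once. A sketch, assuming each array is stored as its own .npy file (the file names and intersect_npy_files are hypothetical); np.load(..., mmap_mode="r") memory-maps each file so it is paged in as the sweep reads it. drop_missing is repeated from above with an optional-numba fallback that is not part of the answer.

```python
import os
import tempfile

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback without numba
    def njit(func):
        return func


@njit
def drop_missing(intersect, sample):
    # Two-pointer sweep from the answer: keeps elements present in both.
    i = j = k = 0
    out = np.empty_like(intersect)
    while i < intersect.size and j < sample.size:
        if intersect[i] == sample[j]:
            out[k] = intersect[i]
            k += 1
            i += 1
            j += 1
        elif intersect[i] < sample[j]:
            i += 1
        else:
            j += 1
    return out[:k]


def intersect_npy_files(paths):
    """Intersect sorted arrays stored as .npy files, one file at a time.

    Requires at least one path. The first file is loaded fully; the rest are
    memory-mapped read-only and swept once each.
    """
    paths = iter(paths)
    intersect = np.load(next(paths))
    for path in paths:
        # np.asarray gives a plain ndarray view of the memory map (no copy).
        sample = np.asarray(np.load(path, mmap_mode="r"))
        intersect = drop_missing(intersect, sample)
        if intersect.size == 0:
            break  # the intersection can only shrink, so stop early
    return intersect


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for idx, arr in enumerate([[1, 2, 3, 5, 8], [1, 2, 5, 8, 9], [0, 1, 5, 8]]):
        path = os.path.join(tmp, f"sample_{idx}.npy")  # hypothetical file names
        np.save(path, np.array(arr, dtype=np.int64))
        paths.append(path)
    result = intersect_npy_files(paths)
    print(result)  # [1 5 8]
```

Because the intersection only ever shrinks, the early exit is safe; with p close to 1 it will rarely trigger, so the cost stays one linear pass per file plus loading.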

Answered Oct 10 '22 02:10 by B. M.