
Efficient way to check whether high-dimensional arrays in two ndarrays overlap in Python

For example, I have two ndarrays: the shape of train_dataset is (10000, 28, 28) and the shape of val_dateset is (2000, 28, 28).

Other than iterating with loops, is there an efficient way to use NumPy's array functions to find the overlap between the two ndarrays?

asked Dec 15 '22 by Vicky


2 Answers

One trick I learnt from Jaime's excellent answer here is to use an np.void dtype in order to view each row in the input arrays as a single element. This allows you to treat them as 1D arrays, which can then be passed to np.in1d or one of the other set routines.

import numpy as np

def find_overlap(A, B):

    if not A.dtype == B.dtype:
        raise TypeError("A and B must have the same dtype")
    if not A.shape[1:] == B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")

    # reshape A and B to 2D arrays. force a copy if necessary in order to
    # ensure that they are C-contiguous.
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))

    # void type that views each row in A and B as a single item
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))

    # use in1d to find rows in A that are also in B
    return np.in1d(A.view(t), B.view(t))

For example:

gen = np.random.RandomState(0)

A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

A_in_B = find_overlap(A, B)

print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True

This method is much more memory-efficient than Divakar's, since it doesn't require broadcasting out to an (m, n, ...) boolean array. In fact, if A and B are row-major then no copying is required at all.
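
If you want to convince yourself of the no-copy claim, here's a quick check with np.shares_memory (a side sketch, not part of the timings below): for C-contiguous input the reshape, the ascontiguousarray call and the void view are all views onto the same buffer.

A = np.zeros((1000, 28, 28))
A2 = np.ascontiguousarray(A.reshape(A.shape[0], -1))      # no copy for C-contiguous input
t = np.dtype((np.void, A2.dtype.itemsize * A2.shape[1]))
print(np.shares_memory(A, A2.view(t)))                     # True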


For comparison, I've slightly adapted Divakar's and B. M.'s solutions.

def divakar(A, B):
    A.shape = A.shape[0], -1
    B.shape = B.shape[0], -1
    return (B[:,None] == A).all(axis=(2)).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)
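
As a quick sanity check (a small sketch, separate from the benchmarks below), all three implementations should flag the same rows of A on a small example:

gen = np.random.RandomState(0)
A = gen.randn(100, 28, 28)
B = A[gen.choice(100, size=20, replace=False)]

res_mine = find_overlap(A, B)
res_bm = bm(A, B)
res_div = divakar(A, B)   # note: divakar() reshapes A and B in place

print(np.array_equal(res_mine, res_bm) and np.array_equal(res_mine, res_div))
# True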

Benchmarks:

In [1]: na = 1000; nb = 200; rowshape = 28, 28

In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
   ....: 
1 loops, best of 3: 244 ms per loop

In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
100 loops, best of 3: 2.81 ms per loop

In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
100 loops, best of 3: 15 ms per loop

As you can see, B. M.'s solution is faster than mine for small n (2.81 ms vs 15 ms above), but np.in1d scales better than testing equality over every pair of rows (roughly O(n log n) rather than O(n²) complexity).

In [5]: na = 10000; nb = 2000; rowshape = 28, 28

In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
1 loops, best of 3: 271 ms per loop

In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
10 loops, best of 3: 123 ms per loop

Divakar's solution is intractable on my laptop for arrays of this size, since it requires generating a 15GB intermediate array whereas I only have 8GB RAM.
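
For reference, that 15GB figure is roughly the size of the (nb, na, 28*28) boolean array that broadcasting has to materialise, at one byte per element:

na, nb = 10000, 2000
print(nb * na * 28 * 28 / 1e9)   # ~15.7 GB of booleans before any reduction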

answered Apr 18 '23 by ali_m


Memory permitting, you could use broadcasting, like so -

val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]

Sample run -

In [55]: train_dataset
Out[55]: 
array([[[1, 1],
        [1, 1]],

       [[1, 0],
        [0, 0]],

       [[0, 0],
        [0, 1]],

       [[0, 1],
        [0, 0]],

       [[1, 1],
        [1, 0]]])

In [56]: val_dateset
Out[56]: 
array([[[0, 1],
        [1, 0]],

       [[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

In [57]: val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
Out[57]: 
array([[[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

If the elements are integers, you could collapse each axis=(1,2) block of the input arrays into a single scalar, treating the block's entries as digits of a linear index, and then efficiently use np.in1d or np.intersect1d to find the matches, as sketched below.
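
A minimal sketch of that idea (using made-up toy data with small 2x2 integer blocks so the collapsed IDs fit comfortably in an integer dtype; full 28x28 blocks would need a larger base/dtype or a collision-free hash):

gen = np.random.RandomState(0)
train_dataset = gen.randint(0, 10, size=(1000, 2, 2))    # toy integer data in [0, 10)
val_dateset = train_dataset[gen.choice(1000, size=200, replace=False)]

base = 10                                                 # number of distinct values
weights = base ** np.arange(train_dataset[0].size)        # positional weight per element
train_ids = train_dataset.reshape(len(train_dataset), -1).dot(weights)
val_ids = val_dateset.reshape(len(val_dateset), -1).dot(weights)

# rows of val_dateset that also occur in train_dataset
overlap = val_dateset[np.in1d(val_ids, train_ids)]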

answered Apr 18 '23 by Divakar