 

Map numpy's `in1d` over 2D array

I have two 2D NumPy arrays:

import numpy as np    
a = np.array([[  1,  15,  16, 200,  10],
              [ -1,  10,  17,  11,  -1],
              [ -1,  -1,  20,  -1,  -1]])

g = np.array([[  1,  12,  15, 100,  11],
              [  2,  13,  16, 200,  12],
              [  3,  14,  17, 300,  13],
              [  4,  17,  18, 400,  14],
              [  5,  20,  19, 500,  16]])

What I want to do is, for each column of g, to check if it contains any element from the corresponding column of a. For the first column, I want to check if any of the values [1,2,3,4,5] appears in [1,-1,-1] and return True. For the second, I want to return False because no element in [12,13,14,17,20] appears in [15,10,-1]. At the moment, I do this using Python's list comprehension. Running

result = [np.any(np.in1d(g[:, i], a[:, i])) for i in range(a.shape[1])]

calculates the correct result, but is getting slow when a has a lot of columns. Is there a more "pure numpy" way of doing this same thing? I feel like there should be an axis keyword one could add to the numpy.in1d function, but there isn't any...
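One way to get the effect of the missing axis keyword, using a different technique from either answer below, is to shift each column's values into its own disjoint integer range. After that, a single flat np.isin call (np.isin is the newer spelling of np.in1d) cannot match values across columns. This is a minimal sketch that assumes integer arrays with the same number of columns. The helper name columnwise_isin_any is hypothetical. If the value range times the number of columns is very large, the offset arithmetic can overflow int64.

```python
import numpy as np

def columnwise_isin_any(g, a):
    """Hypothetical helper: for each column c, is any value of g[:, c] also in a[:, c]?

    Assumes integer arrays with the same number of columns.
    """
    g = np.atleast_2d(g)
    a = np.atleast_2d(a)
    lo = min(g.min(), a.min())
    span = max(g.max(), a.max()) - lo + 1   # width of one column's value range
    offsets = np.arange(g.shape[1]) * span  # column c is shifted into [c*span, (c+1)*span)
    g_keys = (g - lo) + offsets             # offsets broadcast across the rows
    a_keys = (a - lo) + offsets
    # One flat membership test; keys from different columns can never coincide.
    return np.isin(g_keys, a_keys).any(axis=0)

a = np.array([[  1,  15,  16, 200,  10],
              [ -1,  10,  17,  11,  -1],
              [ -1,  -1,  20,  -1,  -1]])
g = np.array([[  1,  12,  15, 100,  11],
              [  2,  13,  16, 200,  12],
              [  3,  14,  17, 300,  13],
              [  4,  17,  18, 400,  14],
              [  5,  20,  19, 500,  16]])

print(columnwise_isin_any(g, a))  # [ True False  True  True False]
```

The cost is roughly that of one sort-based np.isin over all elements. This avoids both the Python-level loop over columns and a large 3D intermediate.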

5xum asked Mar 05 '26 08:03


2 Answers

I'd use broadcasting, but whether that is practical depends on the size of your arrays and the RAM available to you, because the intermediate array has g.shape[0] * g.shape[1] * a.shape[0] elements:

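# g: (rows_g, cols) -> (rows_g, cols, 1);  a.T: (cols, rows_a) -> (1, cols, rows_a)
M = g.reshape(g.shape+(1,)) - a.T.reshape((1,a.shape[1],a.shape[0]))
# M == 0 marks every pair g[r, c] == a[k, c]; collapse the g-row axis, then the a-row axis
np.any(np.any(M == 0, axis=0), axis=1)
# returns:
# array([ True, False,  True,  True, False])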

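It's easiest to see with pen and paper and small test arrays (see the image below). In short, g is reshaped to (rows_g, cols, 1) and a.T to (1, cols, rows_a). Broadcasting then pairs each column of g with every element from the same column of a, effectively one copy of the column per row of a. The difference is zero exactly where a value in g's column equals one in a's column. The two any calls collapse both row axes, leaving one True/False per column. It is the same algorithm as your loop, just vectorized.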
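The same broadcasting can be written with None (np.newaxis) indexing and a direct equality test, which skips the subtraction. The boolean intermediate also uses 1 byte per element instead of the integer array's 8 bytes (for int64). This sketch shows the intermediate shape on the question's arrays and is a restatement of the answer's idea, not a different algorithm.

```python
import numpy as np

a = np.array([[  1,  15,  16, 200,  10],
              [ -1,  10,  17,  11,  -1],
              [ -1,  -1,  20,  -1,  -1]])
g = np.array([[  1,  12,  15, 100,  11],
              [  2,  13,  16, 200,  12],
              [  3,  14,  17, 300,  13],
              [  4,  17,  18, 400,  14],
              [  5,  20,  19, 500,  16]])

# g[:, :, None] has shape (5, 5, 1); a.T[None, :, :] has shape (1, 5, 3).
# Broadcasting compares every g[r, c] with every a[k, c] of the same column c.
matches = g[:, :, None] == a.T[None, :, :]
print(matches.shape)  # (5, 5, 3): rows of g, columns, rows of a

# Reduce over the rows of g (axis 0) and the rows of a (axis 2): one flag per column.
result = matches.any(axis=(0, 2))
print(result)  # [ True False  True  True False]
```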

Caveat: the reshapes assume both arrays are 2D. If g or a is 1D, convert it first, for example with np.atleast_2d, which turns shape (n,) into a single row of shape (1, n).
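A small sketch of that caveat. The hypothetical helper any_match_per_column promotes its inputs with np.atleast_2d before broadcasting, so a 1D a is treated as one row:

```python
import numpy as np

def any_match_per_column(g, a):
    """Hypothetical helper: True for column c if any value of g[:, c] is in a[:, c]."""
    # np.atleast_2d leaves 2D input alone and turns shape (n,) into (1, n).
    g = np.atleast_2d(g)
    a = np.atleast_2d(a)
    return (g[:, :, None] == a.T[None, :, :]).any(axis=(0, 2))

# a given as a 1D array: a single row with one value per column.
print(any_match_per_column(np.array([[1, 2], [3, 4]]), np.array([3, 5])))  # [ True False]
```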

Speed gains:

  • on your example arrays: roughly a factor of 20

    • Python for loop: 301 µs per loop
    • vectorized: 15.4 µs per loop
  • on larger arrays: roughly a factor of 80

    In [2]: a = np.random.random_integers(-2, 3, size=(4, 50))
    
    In [3]: b = np.random.random_integers(-20, 30, size=(35, 50))
    
    In [4]: %timeit np.any(np.any(b.reshape(b.shape+(1,)) - a.T.reshape((1,a.shape[1],a.shape[0])) == 0, axis=0), axis=1)
    10000 loops, best of 3: 39.5 us per loop
    
    In [5]: %timeit [np.any(np.in1d(b[:,i], a[:, i])) for i in range(a.shape[1])]
    100 loops, best of 3: 3.13 ms per loop
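np.random.random_integers, used in the session above, is deprecated in current NumPy. This sketch re-runs the same comparison and assumes NumPy 1.17 or later for np.random.default_rng. It passes endpoint=True to keep random_integers' inclusive upper bound and uses np.isin in place of np.in1d in the loop. Timings will vary by machine and NumPy version. The assert checks that both versions agree.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
# endpoint=True makes the upper bound inclusive, matching random_integers(-2, 3, ...).
a = rng.integers(-2, 3, size=(4, 50), endpoint=True)
b = rng.integers(-20, 30, size=(35, 50), endpoint=True)

def loop_version():
    # The question's per-column loop.
    return np.array([np.any(np.isin(b[:, i], a[:, i])) for i in range(a.shape[1])])

def broadcast_version():
    # The broadcasting expression from this answer.
    M = b.reshape(b.shape + (1,)) - a.T.reshape((1, a.shape[1], a.shape[0]))
    return np.any(np.any(M == 0, axis=0), axis=1)

assert np.array_equal(loop_version(), broadcast_version())
for f in (loop_version, broadcast_version):
    per_call = timeit.timeit(f, number=200) / 200
    print(f"{f.__name__}: {per_call * 1e6:.1f} us per call")
```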
    

[Image: broadcasting explained]

Oliver W. answered Mar 06 '26 21:03


Instead of processing the input column by column, you can process it row by row. For example, first check whether each element of the first row of a appears in its corresponding column of g. Any column where a match is found is already True, so later rows no longer need to look at it.

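import numpy as np

idx = np.arange(a.shape[1])              # columns still undecided
result = np.zeros(idx.size, dtype=bool)  # every column starts as False

for j in range(a.shape[0]):
    # delete this print in production
    print("%d line, I look only at columns " % (j + 1), idx)
    line_pruned = a[j, idx]              # row j of a, restricted to undecided columns
    g_pruned = g[:, idx]                 # the same columns of g
    # positions within the pruned arrays where some g value equals a[j, c]
    hit_cols = np.unique(np.where(g_pruned == line_pruned)[1])
    positive_idx = idx[hit_cols]         # map pruned positions back to original columns
    # delete this print in production
    print("positive hit on the ", positive_idx, " -th columns")
    result[positive_idx] = True
    idx = np.setdiff1d(idx, positive_idx)  # stop looking at columns already found
    if not idx.size:
        break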

To understand how it works, we can consider a different input:

a = np.array([[  0,  15,  16, 200,  10],
              [ -1,  10,  17,  11,  -1],
              [  1,  -1,  20,  -1,  -1]])

g = np.array([[  1,  12,  15, 100,  11],
              [  2,  13,  16, 200,  12],
              [  3,  14,  17, 300,  13],
              [  4,  17,  18, 400,  14],
              [  5,  20,  19, 500,  16]])

The output of the script is:

1 line, I look only at columns  [0 1 2 3 4]
positive hit on the  [2 3]  -th columns
2 line, I look only at columns  [0 1 4]
positive hit on the  []  -th columns
3 line, I look only at columns  [0 1 4]
positive hit on the  [0]  -th columns

You can see that in the 2nd and 3rd passes through the loop, columns 2 and 3 (0-based, as in the printed indices) are no longer processed.

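How well this performs depends on the data as much as on the shapes. It is fastest when matches are found early, so that many columns are pruned after the first few rows of a, and it pays off most when a has many rows. With few or no matches, every row of a is still compared against almost every column, so there is little to gain over the plain approach.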
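To use the row-pruning loop in production, one option is to wrap it in a function without the prints and check it against the broadcasting answer on random data. This is a sketch; the function names and the random test data are illustrative assumptions.

```python
import numpy as np

def any_match_by_rows(g, a):
    """Row-pruning variant (no prints): True for column c if any g[:, c] is in a[:, c]."""
    idx = np.arange(a.shape[1])                 # columns still undecided
    result = np.zeros(a.shape[1], dtype=bool)
    for row in a:
        # Compare this row of a against the still-undecided columns of g.
        hit_cols = np.unique(np.where(g[:, idx] == row[idx])[1])
        found = idx[hit_cols]                   # back to original column indices
        result[found] = True
        idx = np.setdiff1d(idx, found)
        if idx.size == 0:                       # every column already matched
            break
    return result

def any_match_broadcast(g, a):
    """Broadcasting version from the other answer, for comparison."""
    return (g[:, :, None] == a.T[None, :, :]).any(axis=(0, 2))

rng = np.random.default_rng(1)
a = rng.integers(-5, 5, size=(6, 40))
g = rng.integers(-20, 20, size=(30, 40))
print(np.array_equal(any_match_by_rows(g, a), any_match_broadcast(g, a)))
```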

gg349 answered Mar 06 '26 20:03