
Finding first non-zero value along axis of a sorted two dimensional numpy array

I'm trying to find the fastest way to find the first non-zero value for each row of a two dimensional sorted array. Technically, the only values in the array are zeros and ones, and it is "sorted".

For instance, the array could look like the following:

v =

0 0 0 1 1 1 1 
0 0 0 1 1 1 1 
0 0 0 0 1 1 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 0

I could use the argmax function

np.argmax(v, axis=1)

to find when it changes from zero to one, but I believe this would do an exhaustive search along each row. My array will be reasonably sized (~2000x2000). Would argmax still outperform just doing a searchsorted approach for each row within a for loop, or is there a better alternative?

Also, the array will always be such that the first position of a one in a row is >= the first position of a one in the row above it (but it is not guaranteed that there will be a one in the last few rows). I could exploit this with a for loop and a "starting index value" for each row equal to the position of the first 1 from the previous row, but am I correct in thinking that the numpy argmax function will still outperform a loop written in Python?
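A minimal sketch of that loop, combining the per-row searchsorted idea with the "starting index" trick. It assumes every row is sorted and that the ordering between rows described above holds; the name first_one_per_row and the -1 sentinel for rows without a one are illustrative choices, not part of numpy.

```python
import numpy as np

def first_one_per_row(v):
    """Return the column of the first 1 in each row, or -1 if the row has none."""
    h, w = v.shape
    out = np.full(h, -1, dtype=np.intp)
    start = 0  # column of the first 1 found in the previous row
    for i in range(h):
        # Binary search only in the part of the row not ruled out by the row above.
        j = start + int(np.searchsorted(v[i, start:], 1, side="left"))
        if j == w:
            # No 1 in this row; under the assumed ordering, none in later rows either.
            break
        out[i] = j
        start = j
    return out

v = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 0, 1, 1, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 0]])

print(first_one_per_row(v).tolist())  # [3, 3, 4, 6, 6, 6, -1]
```

Each row costs one searchsorted call, so the Python-level overhead still grows with the number of rows; whether that beats a single argmax call is exactly what would need measuring.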

I would just benchmark the alternatives, but the edge length of the array could change quite a bit (from 250 to 10,000).

user1554752 asked Jul 31 '12 00:07


2 Answers

It is reasonably fast to use np.where:

>>> a
array([[0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0]])
>>> np.where(a>0)
(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5]), array([3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 6, 6, 6]))

That returns a tuple of two arrays holding the row and column coordinates of the values greater than 0.
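If only the first hit per row is needed, those two coordinate arrays can be reduced without a Python loop. The sketch below is one possible way to do it, relying on np.nonzero returning indices in row-major order and on np.unique reporting the first occurrence of each value; first_hits is a made-up helper name, and all-zero rows are simply absent from the result.

```python
import numpy as np

a = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 0, 1, 1, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 0]])

def first_hits(a):
    """Return (rows, cols): for each row with a value > 0, the column of its first hit."""
    # Row-major order: within one row, the column indices come out ascending.
    rows, cols = np.nonzero(a > 0)
    # return_index gives the position of the first occurrence of each row number.
    uniq_rows, first = np.unique(rows, return_index=True)
    return uniq_rows, cols[first]

rows, cols = first_hits(a)
print(dict(zip(rows.tolist(), cols.tolist())))
# {0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6}
```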

You can also use np.where to test each sub array:

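def first_true1(a):
    """Return a dict mapping each row number to the index of its first value > 0 (None if there is none)."""
    di={}
    for i in range(len(a)):
        idx=np.where(a[i]>0)    # tuple holding one array of matching column indices
        try:
            di[i]=idx[0][0]     # first match in this row
        except IndexError:      # empty array: no value > 0 in this row
            di[i]=None

    return di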

Prints:

{0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6, 6: None}

i.e., row 0: the first value > 0 is at index 3; row 2: at index 4; row 6: no value greater than 0

As you suspect, argmax may be faster:

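def first_true2():
    # returns the same dict as first_true1
    di={}
    for i in range(len(a)):
        idx=np.argmax(a[i])
        # argmax returns 0 both for an all-zero row and for a row that starts
        # with a 1, so test the value at idx rather than idx itself
        if a[i][idx]>0:
            di[i]=idx
        else:
            di[i]=None

    return di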

If you can live without a None entry for rows of all zeros (such rows are simply absent from the dict), this is faster still on the example array:

def first_true3():
    di={}
    for i, j in zip(*np.where(a>0)):
        if i in di:
            continue
        else:
            di[i]=j

    return di      

And here is a version that uses axis in argmax (as suggested in your comments):

def first_true4():
    di={}
    for i, ele in enumerate(np.argmax(a,axis=1)):
        if ele==0 and a[i][0]==0:
            di[i]=None
        else:
            di[i]=ele

    return di          
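The same "argmax along axis 1, then check for all-zero rows" idea can also be written without any Python-level loop. This is a sketch rather than one of the benchmarked functions: first_true_vectorized and its missing parameter are hypothetical names, and it returns an array with a sentinel (default -1) instead of a dict with None.

```python
import numpy as np

def first_true_vectorized(a, missing=-1):
    """Return, per row, the index of the first value > 0, or `missing` if there is none."""
    a = np.asarray(a)
    idx = a.argmax(axis=1)                        # first position of the row maximum
    hit = a[np.arange(a.shape[0]), idx] > 0       # False for rows that are all zeros
    return np.where(hit, idx, missing)

a = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 1, 1, 1, 1],
              [0, 0, 0, 0, 1, 1, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0, 0, 0]])

print(first_true_vectorized(a).tolist())  # [3, 3, 4, 6, 6, 6, -1]
```

Checking the value at the argmax position, rather than the position itself, keeps a row that starts with a 1 (index 0) distinct from a row with no 1 at all.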

For speed comparisons (on your example array), I get this:

            rate/sec usec/pass first_true1 first_true2 first_true3 first_true4
first_true1   23,818    41.986          --      -34.5%      -63.1%      -70.0%
first_true2   36,377    27.490       52.7%          --      -43.6%      -54.1%
first_true3   64,528    15.497      170.9%       77.4%          --      -18.6%
first_true4   79,287    12.612      232.9%      118.0%       22.9%          --

If I scale that to a 2000 X 2000 np array, here is what I get:

            rate/sec  usec/pass first_true3 first_true1 first_true2 first_true4
first_true3        3 354380.107          --       -0.3%      -74.7%      -87.8%
first_true1        3 353327.084        0.3%          --      -74.6%      -87.7%
first_true2       11  89754.200      294.8%      293.7%          --      -51.7%
first_true4       23  43306.494      718.3%      715.9%      107.3%          --
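
The tables above do not show how the timings were taken. One way to get comparable measurements at other sizes, using only the standard library's timeit, is sketched below. make_staircase and bench are hypothetical helpers, and the random staircase is only an assumption about what a typical input looks like.

```python
import timeit
import numpy as np

def make_staircase(n, seed=0):
    """Build an n x n array of 0s and 1s whose first-1 column never decreases down the rows."""
    rng = np.random.default_rng(seed)
    # A start of n means the row contains no 1 at all.
    starts = np.sort(rng.integers(0, n + 1, size=n))
    return (np.arange(n)[None, :] >= starts[:, None]).astype(np.int64)

def bench(fn, a, number=10):
    """Return the best average time per call of fn(a), in seconds."""
    return min(timeit.repeat(lambda: fn(a), number=number, repeat=3)) / number

a = make_staircase(2000)
t = bench(lambda arr: np.argmax(arr, axis=1), a)
print(f"argmax(axis=1) on {a.shape}: {t * 1e6:.0f} usec/pass")
```

Passing any of the functions above to bench in place of the lambda (after giving it an array parameter where it lacks one) would yield usec/pass figures in the same spirit as the tables.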
dawg answered Oct 29 '22 15:10


argmax() uses a C-level loop, which is much faster than a Python loop, so even a smart algorithm written in pure Python will have a hard time beating argmax(). You can use Cython to speed it up:

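cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def find(int[:,:] a):
    cdef int h = a.shape[0]
    cdef int w = a.shape[1]
    cdef int i, j
    cdef int idx = 0    # column where the previous row's first 1 was found
    cdef list r = []
    for i in range(h):
        # the first 1 cannot lie left of idx, so resume the scan there
        for j in range(idx, w):
            if a[i, j] == 1:
                idx = j
                r.append(idx)
                break
        else:
            # inner loop finished without break: no 1 in this row
            r.append(-1)
    return r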

On my PC, for a 2000x2000 matrix, it's 100 µs (Cython) vs 3 ms (argmax).

HYRY answered Oct 29 '22 15:10