
Fastest way to find indices of condition in numpy array

I'm trying to find the fastest way to get the functionality of numpy's 'where' statement on a 2D numpy array; namely, retrieving the indices where a condition is met. It is simply much slower than in other languages I have used (e.g., IDL, Matlab).
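For readers new to NumPy: with a single argument, `np.where(cond)` is shorthand for `np.asarray(cond).nonzero()`. On a 2D array it returns a tuple of two index arrays, row indices and column indices, listed in row-major order. The small 3x4 array below is made up purely for illustration:

```python
import numpy as np

# Small made-up array with two cells equal to 10
data = np.zeros((3, 4))
data[1, 2] = 10
data[2, 0] = 10

rows, cols = np.where(data == 10)   # one array per axis
print(rows)  # [1 2]
print(cols)  # [2 0]
```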

I have cythonized a function that marches through the array in nested for-loops. There is almost an order of magnitude increase in speed, but I would like to increase performance even more, if possible.

TEST.py:

from cython_where import *
import time
import numpy as np

data = np.zeros((2600,5200))
data[100:200,100:200] = 10

t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print(time.time() - t0)

t1 = time.time()
tmp = np.where(data == 10)
print(time.time() - t1)

My cython_where.pyx program:

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

# Cython's docs recommend calling import_array whenever numpy is cimported
np.import_array()

DTYPE1 = np.float64
ctypedef np.float64_t DTYPE1_t
DTYPE2 = np.intp
ctypedef np.intp_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
  assert data.dtype == DTYPE1

  cdef Py_ssize_t xmax = data.shape[0]
  cdef Py_ssize_t ymax = data.shape[1]
  cdef Py_ssize_t x, y
  cdef Py_ssize_t count = 0
  # Size the outputs for the worst case (every cell matches): with
  # boundscheck(False), a fixed-size buffer would overflow silently
  cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.empty(xmax * ymax, dtype=DTYPE2)
  cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.empty(xmax * ymax, dtype=DTYPE2)
  if oper == 'EQ' or oper == 'eq':  # I didn't want to include GT, GE, LT, LE here
    for x in range(xmax):
      for y in range(ymax):
        if data[x, y] == val:
          xind[count] = x
          yind[count] = y
          count += 1

  return (xind[:count], yind[:count]), count

Output of TEST.py:

cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
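A single `time.time()` difference per function can vary noticeably from run to run. A sketch of a steadier measurement with the standard-library `timeit` module, keeping the best of several repeats in the spirit of the "best of 3" `%timeit` figures in the answer below. Only the pure-NumPy calls are timed here; the compiled `cython_where` could be added to the dictionary the same way:

```python
import timeit
import numpy as np

# Same test array as TEST.py
data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

candidates = {
    "np.where": lambda: np.where(data == 10),
    "np.argwhere": lambda: np.argwhere(data == 10),
    # once the extension is built:
    # "cython_where": lambda: cython_where(data, 'EQ', 10),
}

results = {}
for name, fn in candidates.items():
    # 3 repeats of 2 calls each; keep the fastest repeat, per call
    best = min(timeit.repeat(fn, repeat=3, number=2)) / 2
    results[name] = best
    print(f"{name}: {best * 1e3:.1f} ms per call")
```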

I've also tried numpy's argwhere, which is about as fast as where. I'm pretty new to numpy and cython, so if you have any other ideas to really increase performance, I'm all ears!
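For comparison, `np.argwhere(cond)` finds the same indices as `np.where(cond)` but stacks them into an (N, 2) array with one [row, column] pair per match; NumPy documents it as almost the same as `np.transpose(np.nonzero(a))`, which would explain why its timing is so close. A quick check on the TEST.py array:

```python
import numpy as np

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10
mask = data == 10

rows, cols = np.where(mask)        # tuple of per-axis index arrays
pairs = np.argwhere(mask)          # (N, 2) array, one [row, col] per match

print(pairs.shape)                                        # (10000, 2)
print(np.array_equal(pairs, np.column_stack((rows, cols))))  # True
```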

weather guy asked Oct 18 '22 17:10



1 Answer

Contributions:

  • NumPy can be sped up by working on the flattened array, for a 4x gain:

    %timeit np.where(data==10)
    1 loops, best of 3: 105 ms per loop
    
    %timeit np.unravel_index(np.where(data.ravel()==10)[0],data.shape)
    10 loops, best of 3: 26.0 ms per loop
    

I think you can apply the same idea to your Cython code: loop once over the flattened array instead of nesting two loops, which avoids computing the offset k=i*ncol+j for each cell.
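A plain-Python sketch of that loop structure, meant as a template to port to Cython rather than as fast code itself: walk the flattened array with a single index `k` and convert `k` back to (row, column) only when a cell matches. The helper name `where_equal_flat` is hypothetical. `ravel()` returns a view for C-contiguous input and a row-major copy otherwise, so the index arithmetic holds for either layout:

```python
import numpy as np

def where_equal_flat(data, val):
    """Hypothetical helper: (rows, cols) of cells equal to val, scanned as 1D."""
    ncol = data.shape[1]
    flat = data.ravel()                      # row-major 1D view (or copy)
    rows = np.empty(flat.size, dtype=np.intp)
    cols = np.empty(flat.size, dtype=np.intp)
    count = 0
    for k in range(flat.size):               # one flat index, no i*ncol + j per cell
        if flat[k] == val:
            # row and column recovered only for matching cells
            rows[count], cols[count] = divmod(k, ncol)
            count += 1
    return rows[:count], cols[:count]

data = np.zeros((4, 5))
data[1:3, 2:4] = 10
print(where_equal_flat(data, 10))
```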

  • Numba gives a simple alternative:

    import numpy as np
    from numba import jit

    @jit(nopython=True)
    def numbaeq(flatdata,x,nrow,ncol):
      size=ncol*nrow
      # integer index arrays, sized for the worst case (every cell matches)
      ix=np.empty(size,dtype=np.int64)
      jx=np.empty(size,dtype=np.int64)
      count=0
      for k in range(size):
        if flatdata[k]==x:
          ix[count]=k//ncol   # row
          jx[count]=k%ncol    # column
          count+=1
      return ix[:count],jx[:count]

    def whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)
    

which gives:

    %timeit whereequal(data,10)
    10 loops, best of 3: 20.2 ms per loop

Numba doesn't optimise this kind of problem very well; its performance stays below Cython's.

  • k//ncol and k%ncol can be computed at the same time with an optimized divmod operation.
  • The ultimate steps are assembly language and parallelisation, but that's another sport.
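Outside a compiled loop, NumPy offers the same quotient/remainder pairing in vectorized form as `np.divmod`. A minimal sketch, swapping in `np.flatnonzero` to get the row-major flat indices of the matches and then splitting them into rows and columns in one call; for a 2D array this yields the same indices as `np.where` or `np.unravel_index`:

```python
import numpy as np

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

k = np.flatnonzero(data == 10)              # row-major flat indices of matches
rows, cols = np.divmod(k, data.shape[1])    # k // ncol and k % ncol together

r2, c2 = np.where(data == 10)
print(np.array_equal(rows, r2) and np.array_equal(cols, c2))  # True
```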
B. M. answered Oct 31 '22 11:10
