I'm trying to find the fastest way to get the functionality of numpy's 'where' statement on a 2D numpy array; namely, retrieving the indices where a condition is met. It is simply much slower than in other languages I have used (e.g., IDL, Matlab).
I have cythonized a function that marches through the array in nested for-loops. There is almost an order of magnitude increase in speed, but I would like to increase performance even more, if possible.
TEST.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print(time.time() - t0)
t1 = time.time()
tmp = np.where(data == 10)
print(time.time() - t1)
My cython_where.pyx program:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

DTYPE1 = np.float64
ctypedef np.float64_t DTYPE1_t
DTYPE2 = np.int64
ctypedef np.int64_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
    assert data.dtype == DTYPE1
    cdef int xmax = data.shape[0]
    cdef int ymax = data.shape[1]
    cdef unsigned int x, y
    cdef int count = 0
    # preallocated index buffers; assumes at most 100000 matches
    cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000, dtype=DTYPE2)
    cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000, dtype=DTYPE2)
    if oper == 'EQ' or oper == 'eq':  # I didn't want to include GT, GE, LT, LE here
        for x in range(xmax):
            for y in range(ymax):
                if data[x, y] == val:
                    xind[count] = x
                    yind[count] = y
                    count += 1
    return (xind[:count], yind[:count]), count
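For reference, a minimal build script for the extension might look like the sketch below (the filenames match the listings above; the rest is an assumption about your build setup, built with python setup.py build_ext --inplace):

# setup.py -- minimal build sketch for cython_where.pyx
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

ext = Extension(
    "cython_where",
    ["cython_where.pyx"],
    include_dirs=[np.get_include()],  # needed for "cimport numpy"
)
setup(ext_modules=cythonize([ext]))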
Output of TEST.py:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
I've also tried numpy's argwhere, which is about as fast as where. I'm pretty new to numpy and cython, so if you have any other ideas to really increase performance, I'm all ears!
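For completeness, argwhere returns the same information as where, just in a different layout: an (N, 2) array of (row, col) pairs rather than a tuple of two index arrays.

import numpy as np

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

# np.argwhere(cond) is equivalent to np.transpose(np.where(cond))
coords = np.argwhere(data == 10)
print(coords.shape)  # (10000, 2)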
Contributions:
Numpy can be sped up by working on the flattened array, for roughly a 4x gain:
%timeit np.where(data==10)
1 loops, best of 3: 105 ms per loop
%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)
10 loops, best of 3: 26.0 ms per loop
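Wrapped up as a helper, the flat-index trick looks like the sketch below (the name where2d is just illustrative; np.flatnonzero is an equivalent, slightly more direct way to get the flat indices):

import numpy as np

def where2d(data, val):
    # test on the raveled view, then convert the flat indices
    # back into (row, col) index arrays
    flat = np.flatnonzero(data.ravel() == val)
    return np.unravel_index(flat, data.shape)

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10
rows, cols = where2d(data, 10)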
I think you can optimize your Cython code with the same idea: marching over the flattened array with a single index avoids computing k = i*ncol + j for each cell. A sketch follows.
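A minimal sketch of what that flattened Cython loop could look like (hypothetical and untested; the function name is made up, and it assumes float64 data like the original):

# flat_where.pyx -- sketch of the flattened loop
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cython_flat_where(np.ndarray[np.float64_t, ndim=1] flat, np.float64_t val, int ncol):
    cdef Py_ssize_t k, count = 0
    cdef Py_ssize_t size = flat.shape[0]
    # record only the flat index in the hot loop;
    # split into rows/cols once, vectorized, at the end
    cdef np.ndarray[np.int64_t, ndim=1] kind = np.empty(size, dtype=np.int64)
    for k in range(size):
        if flat[k] == val:
            kind[count] = k
            count += 1
    kind = kind[:count]
    # called as: cython_flat_where(data.ravel(), 10, data.shape[1])
    return kind // ncol, kind % ncol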
Numba gives a simple alternative:
import numpy as np
from numba import jit

@jit(nopython=True)
def numbaeq(flatdata, x, nrow, ncol):
    size = ncol * nrow
    # index buffers must be integer-typed, not the data's float dtype
    ix = np.empty(size, dtype=np.int64)
    jx = np.empty(size, dtype=np.int64)
    count = 0
    k = 0
    while k < size:
        if flatdata[k] == x:
            ix[count] = k // ncol
            jx[count] = k % ncol
            count += 1
        k += 1
    return ix[:count], jx[:count]

def whereequal(data, x):
    return numbaeq(data.ravel(), x, *data.shape)
which gives:
%timeit whereequal(data,10)
10 loops, best of 3: 20.2 ms per loop
Not a great optimisation from Numba on this problem; it stays below the Cython version's performance.
k//ncol and k%ncol can be computed at the same time with an optimized divmod operation.
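In Python that is literally the built-in divmod, which Numba lists among its supported built-ins in nopython mode; a variant of the kernel above could then read (a sketch, same assumptions as before, name made up):

import numpy as np
from numba import jit

@jit(nopython=True)
def numbaeq_divmod(flatdata, x, nrow, ncol):
    size = ncol * nrow
    ix = np.empty(size, dtype=np.int64)
    jx = np.empty(size, dtype=np.int64)
    count = 0
    for k in range(size):
        if flatdata[k] == x:
            # divmod yields quotient and remainder in one step
            ix[count], jx[count] = divmod(k, ncol)
            count += 1
    return ix[:count], jx[:count]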