I'm trying to find the fastest way to get the functionality of numpy's np.where on a 2D numpy array, namely retrieving the indices where a condition is met. np.where is simply much slower than the equivalent in other languages I have used (e.g., IDL, Matlab).
I have cythonized a function that marches through the array in nested for-loops. It runs almost an order of magnitude faster (about 7x in the timings below), but I would like to increase performance even more, if possible.
TEST.py:
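import time
import numpy as np

# Compile cython_where.pyx on import (assumes it sits next to this script)
import pyximport
pyximport.install(setup_args={'include_dirs': np.get_include()})
from cython_where import cython_where

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

t0 = time.time()
inds, ct = cython_where(data, 'EQ', 10)
print(time.time() - t0)

t1 = time.time()
tmp = np.where(data == 10)
print(time.time() - t1)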
My cython_where.pyx program:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

np.import_array()  # initialise the NumPy C API

# np.float / np.int were removed in NumPy 1.24; use explicit types
DTYPE1 = np.float64
ctypedef np.float64_t DTYPE1_t
DTYPE2 = np.intp
ctypedef np.intp_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
    assert data.dtype == DTYPE1
    cdef Py_ssize_t xmax = data.shape[0]
    cdef Py_ssize_t ymax = data.shape[1]
    cdef Py_ssize_t x, y
    cdef Py_ssize_t count = 0
    cdef Py_ssize_t maxcount = 100000
    cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(maxcount, dtype=DTYPE2)
    cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(maxcount, dtype=DTYPE2)
    if oper == 'EQ' or oper == 'eq':  # I didn't want to include GT, GE, LT, LE here
        for x in range(xmax):
            for y in range(ymax):
                if data[x, y] == val:
                    # boundscheck is off, so guard the fixed-size buffers explicitly
                    if count == maxcount:
                        raise ValueError("more than %d matches" % maxcount)
                    xind[count] = x
                    yind[count] = y
                    count += 1
    return (xind[:count], yind[:count]), count
Output of TEST.py (Cython time first, then np.where, in seconds):
$ python TEST.py
0.0139019489288
0.0982608795166
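Single time.time() measurements like these can be noisy, especially on a first call. A minimal sketch of a steadier comparison using the standard-library timeit module and taking the best of several runs (similar to what IPython's %timeit reports). `time_best` is a hypothetical helper name, and the Cython function is left out so the sketch runs with NumPy alone:

```python
import timeit
import numpy as np

# Same test array as TEST.py: 2600 x 5200 zeros with a 100 x 100 block of 10s
data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

def time_best(stmt, number=3, repeat=3):
    """Best per-call time in seconds over `repeat` runs of `number` calls each."""
    return min(timeit.repeat(stmt, number=number, repeat=repeat)) / number

t_where = time_best(lambda: np.where(data == 10))
t_flat = time_best(lambda: np.unravel_index(np.flatnonzero(data == 10), data.shape))
print(f"np.where:            {t_where * 1e3:.1f} ms")
print(f"flatnonzero+unravel: {t_flat * 1e3:.1f} ms")
```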
I've also tried numpy's argwhere, which is about as fast as where. I'm pretty new to numpy and cython, so if you have any other ideas to really increase performance, I'm all ears!
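For reference, np.argwhere returns the same indices as np.where, arranged as an (N, 2) array with one (row, col) pair per match. NumPy implements it on top of np.nonzero (the same machinery np.where uses when given a single argument), so similar timings are expected. A small check of that relationship:

```python
import numpy as np

# Small example array; the 2-D case behaves the same at any size
data = np.zeros((4, 5))
data[1:3, 2:4] = 10

rows, cols = np.where(data == 10)   # tuple of two 1-D index arrays
pairs = np.argwhere(data == 10)     # (N, 2) array, one (row, col) per match

# argwhere is the column-stacked (transposed) form of the where/nonzero result
assert np.array_equal(pairs, np.column_stack((rows, cols)))
print(pairs)
```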
Contributions:
NumPy can be sped up by working on the flattened array, for about a 4x gain:
%timeit np.where(data==10)
1 loops, best of 3: 105 ms per loop
%timeit np.unravel_index(np.where(data.ravel()==10),data.shape)
10 loops, best of 3: 26.0 ms per loop
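The same idea can be written with np.flatnonzero, which returns a 1-D array of flat positions directly. (np.where on the raveled array returns a 1-tuple, so the unravel_index call above receives a (1, N) array and returns (1, N) arrays rather than 1-D ones.) A minimal sketch, not timed here; `where2d_flat` is a hypothetical helper name:

```python
import numpy as np

def where2d_flat(mask):
    """Indices of the True entries of a 2-D boolean mask, like np.where(mask).

    Finds positions in the flattened array with np.flatnonzero, then converts
    them back to (row, col) with np.unravel_index (both use row-major order).
    """
    flat = np.flatnonzero(mask)  # 1-D positions in the raveled array
    return np.unravel_index(flat, mask.shape)

data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10

rows, cols = where2d_flat(data == 10)
ref_rows, ref_cols = np.where(data == 10)
# Same indices in the same row-major order as np.where
assert np.array_equal(rows, ref_rows) and np.array_equal(cols, ref_cols)
```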
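I think you can apply the same idea to your Cython code: loop over the flattened array with a single index k instead of two nested loops, which avoids computing the offset k=i*ncol+j for every cell, and recover i and j only for the cells that match.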
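A pure-Python sketch of that loop structure: a single flat index k walks the raveled array, and the row and column are computed only for matching cells. It also counts matches in a first pass so the output buffers are sized exactly, instead of using the fixed 100000-element buffers from the question. As written it is slow; the assumption is that in Cython you would type `k`, `count`, `ncol` and the buffers so the loops compile to C. `where_eq_flat` is a hypothetical name:

```python
import numpy as np

def where_eq_flat(data, val):
    """Flat-index 'where equal' loop for a 2-D array (pure Python, slow as written)."""
    ncol = data.shape[1]
    flat = data.ravel()  # view for C-contiguous input, no copy
    n = flat.shape[0]

    # First pass: count matches so the output buffers are exactly sized
    count = 0
    for k in range(n):
        if flat[k] == val:
            count += 1

    rows = np.empty(count, dtype=np.intp)
    cols = np.empty(count, dtype=np.intp)

    # Second pass: only matching cells pay for the row/column computation
    m = 0
    for k in range(n):
        if flat[k] == val:
            rows[m] = k // ncol
            cols[m] = k % ncol
            m += 1
    return rows, cols

small = np.zeros((6, 7))
small[2:4, 1:3] = 10
r, c = where_eq_flat(small, 10)
assert np.array_equal(r, np.where(small == 10)[0])
assert np.array_equal(c, np.where(small == 10)[1])
```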
Numba gives a simple alternative:
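import numpy as np
from numba import jit

@jit(nopython=True)
def numbaeq(flatdata, x, nrow, ncol):
    size = ncol * nrow
    # index arrays must be integers (data.dtype would make them float64)
    ix = np.empty(size, dtype=np.int64)
    jx = np.empty(size, dtype=np.int64)
    count = 0
    for k in range(size):
        if flatdata[k] == x:
            ix[count] = k // ncol  # row index
            jx[count] = k % ncol   # column index
            count += 1
    return ix[:count], jx[:count]

def whereequal(data, x):
    # ravel() returns a view (no copy) when data is C-contiguous
    return numbaeq(data.ravel(), x, *data.shape)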
which gives:
%timeit whereequal(data,10)
10 loops, best of 3: 20.2 ms per loop
Numba doesn't optimise this problem very well: it stays slower than the Cython timing reported in the question.
k//ncol and k%ncol can be computed at the same time with an optimised divmod operation.
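In plain Python the built-in divmod returns both results in one call. For whole arrays, np.divmod does the same element-wise, so the flat positions from np.flatnonzero can be turned into rows and columns without np.unravel_index. A minimal sketch (whether a compiler actually fuses // and % inside a Cython or Numba loop is not checked here):

```python
import numpy as np

ncol = 5200
k = 520123
row, col = divmod(k, ncol)  # same as (k // ncol, k % ncol)
assert (row, col) == (k // ncol, k % ncol)

# Vectorised form: np.divmod on all matching flat positions at once
data = np.zeros((2600, 5200))
data[100:200, 100:200] = 10
flat = np.flatnonzero(data == 10)
rows, cols = np.divmod(flat, data.shape[1])
assert np.array_equal(rows, np.where(data == 10)[0])
assert np.array_equal(cols, np.where(data == 10)[1])
```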