In numpy, if I want to compare two arrays, for example to test whether all elements in A are less than the corresponding values in B, I write if (A < B).all(): as the test. But in practice this requires allocating and evaluating the complete array C = A < B and then calling C.all() on it. This is a bit of a waste. Is there any way to 'shortcut' the comparison, i.e. evaluate A < B directly, element by element (without allocating and calculating the temporary C), and stop and return False as soon as the first failing comparison is found?
Plain Python's "and" and "or" use shortcut evaluation, but numpy does not.
(A < B).all() uses numpy building blocks: broadcasting, the element-by-element comparison with <, and the all reduction. The < works just like other binary operations (plus, times, and, or, gt, le, etc.). And all is like the other reduction methods (any, max, sum, mean), and can operate on the whole array, by rows, or by columns.
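To make those building blocks concrete, here is a small sketch using two arrays that broadcast against each other. The variable names are only for illustration.

```python
import numpy as np

a = np.arange(4)[:, None]      # shape (4, 1)
b = np.arange(2, 5)[None, :]   # shape (1, 3)

c = a < b                      # broadcast comparison, shape (4, 3)
whole = c.all()                # reduce over the whole array
by_column = c.all(axis=0)      # one result per column
by_row = c.all(axis=1)         # one result per row
any_true = c.any()             # any is the matching "or" reduction
```

Note that c is fully computed before any of the reductions run, which is exactly the temporary the question wants to avoid.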
It is possible to write a function that combines the all and < into one iteration, but it would be difficult to get the generality that I just described.
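As a rough illustration of such a combined function, the sketch below feeds a generator to Python's built-in all, which stops at the first False. It uses np.broadcast to pair up the elements. The name all_less is made up for this example, and it only handles < over the whole array, so it has none of the by-row or by-column generality.

```python
import numpy as np

def all_less(a, b):
    """Return True if every broadcast pair satisfies x < y (hypothetical helper)."""
    # np.broadcast yields (x, y) pairs lazily; all() stops at the first False.
    return all(x < y for x, y in np.broadcast(a, b))

a = np.arange(4)[:, None]
b = np.arange(2, 5)[None, :]
result_fail = all_less(a, b)       # stops early at the pair (2, 2)
result_pass = all_less(a, b + 10)  # every pair passes, so all pairs are visited
```

This runs at Python-loop speed, so it is likely to beat (a < b).all() only when a failing pair tends to show up early.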
But if you must implement an iterative solution with a shortcut action, and do it fast, I'd suggest developing the idea with nditer and then compiling it with cython.
http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html is a good tutorial on using nditer, and it takes you through using it in cython. nditer takes care of broadcasting and iteration, letting you concentrate on the comparison and any shortcutting.
Here's a sketch of an iterator that could be cast into cython:
import numpy as np

a = np.arange(4)[:,None]
b = np.arange(2,5)[None,:]
c = np.array(True)   # 0-d result operand, starts as True for an 'all'
it = np.nditer([a, b, c], flags=['reduce_ok'],
               op_flags=[['readonly'], ['readonly'], ['readwrite']])
for x, y, z in it:
    z[...] = x<y
    if not z:
        # first failing pair: report it and stop iterating
        print('>',x,y)
        break
    else:
        print(x,y)
print(z)
with a sample run (under Python 2, where print(x,y) displays a tuple):
1420:~/mypy$ python stack34852272.py
(array(0), array(2))
(array(0), array(3))
(array(0), array(4))
(array(1), array(2))
(array(1), array(3))
(array(1), array(4))
('>', array(2), array(2))
False
Start with a default False and a different break condition, and you get a shortcutting any. Generalizing the test to handle <, <=, etc. will be more work.
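A minimal sketch of that any variant follows. It is a simplification rather than a direct edit of the code above: it drops the read-write result operand and returns straight from the loop, and the name any_less is hypothetical.

```python
import numpy as np

def any_less(a, b):
    """Return True as soon as one broadcast pair satisfies x < y (hypothetical helper)."""
    # nditer broadcasts the two read-only operands and yields 0-d arrays.
    it = np.nditer([a, b], op_flags=[['readonly'], ['readonly']])
    for x, y in it:
        if x < y:
            return True   # shortcut: stop at the first passing pair
    return False          # the default when no pair passes

a = np.arange(4)[:, None]
b = np.arange(2, 5)[None, :]
found = any_less(a, b)          # the first pair (0, 2) already passes
not_found = any_less(a + 4, b)  # 4..7 is never less than 2..4
```

Like the loop above, this is only a Python-level prototype; the speed would have to come from the cython step.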
Get something like this working in Python, and then try it in Cython. If you have trouble with that step, come back with a new question. SO has a good base of Cython users.
How large are your arrays? I would imagine they have to be very large, e.g. A.shape = (1000000,) or larger, before performance becomes an issue. Would you consider using numpy views?
Instead of comparing (A < B).all() or (A < B).any() on the whole arrays, you can compare one slice (a view) at a time, such as (A[:10] < B[:10]).all(). Here's a simple loop that might work:
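k = 0
# Stop at the end of the array; an empty slice would make .all() return True forever.
while k*10 < len(A) and (A[k*10:(k+1)*10] < B[k*10:(k+1)*10]).all():
    k += 1
all_less = k*10 >= len(A)  # True only if every segment passed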
Instead of 10 you can use 100, 10**3, or whatever segment size you wish. Obviously, if your segment size is 1, you are saying:
k = 0
while k < len(A) and A[k] < B[k]:
    k += 1
Sometimes, comparing the entire array can become memory intensive. If A and B each have length 10000 and I need to compare every element of A against every element of B, the boolean temporary alone has 10**8 entries, and I may run out of space.
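The segment idea can be wrapped in a small function. This is a sketch under stated assumptions: A and B are 1-D and of equal length, and the name all_less_chunked and the default chunk size are invented for the example.

```python
import numpy as np

def all_less_chunked(A, B, chunk=1000):
    """Return True if A[i] < B[i] for every i, testing `chunk` elements at a time."""
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 1 or A.shape != B.shape:
        raise ValueError("this sketch assumes 1-D arrays of equal length")
    for start in range(0, len(A), chunk):
        stop = start + chunk
        # Slices are views; only the boolean temporary of one segment is allocated.
        if not (A[start:stop] < B[start:stop]).all():
            return False  # stop at the first failing segment
    return True

A = np.arange(10000)
B = A + 1
B_bad = B.copy()
B_bad[5] = 0                       # make one early comparison fail
ok = all_less_chunked(A, B)        # visits every segment
bad = all_less_chunked(A, B_bad)   # returns after the first segment
```

The segment size trades the cost of Python-level looping against the amount of work wasted after the first failure, so it is worth timing a few values on your own data.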