 

shortcut evaluation of numpy's array comparison

Tags:

python

numpy

In numpy, if I want to compare two arrays, for example to test whether every element of A is less than the corresponding element of B, I use if (A < B).all():. But in practice this requires allocating and evaluating the complete array C = A < B and then calling C.all() on it, which is a bit wasteful. Is there any way to 'shortcut' the comparison, i.e. evaluate A < B directly, element by element (without allocating and computing the temporary C), and stop and return False as soon as the first failing comparison is found?
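To make the cost concrete, here is a minimal sketch (the array size is chosen arbitrarily for illustration) showing that A < B materializes a full boolean temporary, one byte per element, before .all() runs, even when the very first comparison already decides the answer.

```python
import numpy as np

n = 1_000_000
A = np.arange(n)          # 0, 1, 2, ...
B = np.arange(n) + 1      # every A[i] < B[i] ...
B[0] = -1                 # ... except the very first pair

C = A < B                 # full temporary: all n comparisons are computed and stored
print(C.dtype, C.size, C.nbytes)   # bool 1000000 1000000
print(C.all())            # False, although element 0 alone already decided it
```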

asked Jan 18 '16 10:01 by V.K.



2 Answers

Plain Python's and and or operators use shortcut (short-circuit) evaluation, but numpy's element-wise operations do not.

(A < B).all()

uses numpy's building blocks: broadcasting, the element-by-element comparison with <, and the all reduction. The < works just like other binary operations (plus, times, and, or, gt, le, etc.). And all is like other reduction methods (any, max, sum, mean); it can operate on the whole array, or by rows or by columns.
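A small illustration of that generality, using arbitrary example arrays: a column is broadcast against a row, and the resulting boolean array is reduced over everything, per column, or per row.

```python
import numpy as np

a = np.arange(4)[:, None]      # shape (4, 1): the column 0, 1, 2, 3
b = np.arange(2, 5)[None, :]   # shape (1, 3): the row 2, 3, 4

C = a < b                      # broadcast to shape (4, 3)
print(C.shape)                 # (4, 3)
print(C.all())                 # False: reduce over the whole array
print(C.all(axis=0))           # one result per column: False, False, True
print(C.all(axis=1))           # one result per row: True, True, False, False
print(np.less(a, b).any())     # True: same comparison as a ufunc, different reduction
```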

It is possible to write a function that combines the all and < into one iteration, but it would be difficult to get the generality that I just described.
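For example, Python's built-in all() does stop at the first falsy item when it is fed a generator, so a fused version is easy to write for the narrow case. This minimal sketch assumes two 1-D arrays of equal length (it does no broadcasting and has no axis argument, which is exactly the generality that gets lost). Every step runs in the interpreter, so it is usually much slower than the vectorized version unless the failing element comes early.

```python
import numpy as np

def all_less_py(A, B):
    """True if A[i] < B[i] for every i (1-D, equal-length inputs assumed).

    all() consumes the generator lazily and stops at the first False.
    """
    return all(x < y for x, y in zip(A, B))

A = np.array([5, 1, 2, 3])
B = np.array([4, 9, 9, 9])
print(all_less_py(A, B))   # False

# Record which pairs are actually examined before all() gives up.
examined = []

def recorded_pairs():
    for x, y in zip(A, B):
        examined.append((x, y))
        yield x < y

all(recorded_pairs())
print(len(examined))       # 1: only the first pair was compared
```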

But if you must implement an iterative solution, with a shortcut action, and do it fast, I'd suggest developing the idea with nditer and then compiling it with Cython.

http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html is a good tutorial on using nditer, and it walks you through using it in Cython. nditer takes care of broadcasting and iteration, letting you concentrate on the comparison and any shortcutting.
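A quick check of the broadcasting claim: this minimal sketch iterates over a (2, 1) column and a (1, 3) row together. order='C' is passed here only to make the visiting order predictable; the iterator yields one pair of 0-d array views per element of the broadcast (2, 3) shape.

```python
import numpy as np

a = np.array([[10], [20]])     # shape (2, 1)
b = np.array([[1, 2, 3]])      # shape (1, 3)

# nditer broadcasts the operands; each x, y is a 0-d array view.
pairs = [(int(x), int(y)) for x, y in np.nditer([a, b], order='C')]
print(pairs)
# [(10, 1), (10, 2), (10, 3), (20, 1), (20, 2), (20, 3)]
```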

Here's a sketch of an iterator that could be cast into cython:

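import numpy as np

a = np.arange(4)[:,None]      # column, shape (4, 1)
b = np.arange(2,5)[None,:]    # row, shape (1, 3)
c = np.array(True)            # 0-d result operand; 'reduce_ok' lets it span the (4, 3) iteration
it = np.nditer([a, b, c], flags=['reduce_ok'],
    op_flags = [['readonly'], ['readonly'],['readwrite']])
for x, y, z in it:
    z[...] = x<y              # write this element's comparison into the result
    if not z:                 # the first False decides all(): stop iterating
        print('>',x,y)
        break
    else:
        print(x,y)
print(z)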

with a sample run (under Python 2, hence the tuple-style print output):

1420:~/mypy$ python stack34852272.py 
(array(0), array(2))
(array(0), array(3))
(array(0), array(4))
(array(1), array(2))
(array(1), array(3))
(array(1), array(4))
('>', array(2), array(2))
False

Start with a default of False and break on the first True instead, and you get a shortcutting any. Generalizing the test to handle <, <=, etc., will be more work.
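A minimal pure-Python sketch of both ideas, assuming that passing a function from the operator module (lt, le, ge, ...) is an acceptable way to generalize the test. The helpers shortcut_all and shortcut_any are hypothetical names; instead of writing into a reduction operand they simply return at the first deciding element. They are a starting point for a Cython port, not a fast implementation: in plain Python they will be slower than (A < B).all().

```python
import operator
import numpy as np

def shortcut_all(a, b, op=operator.lt):
    """True if op(x, y) holds for every broadcast pair; stops at the first failure."""
    for x, y in np.nditer([a, b]):
        if not op(x, y):
            return False
    return True

def shortcut_any(a, b, op=operator.lt):
    """True if op(x, y) holds for some broadcast pair; stops at the first success."""
    for x, y in np.nditer([a, b]):
        if op(x, y):
            return True
    return False

a = np.arange(4)[:, None]
b = np.arange(2, 5)[None, :]
print(shortcut_all(a, b))                # False (2 < 2 fails)
print(shortcut_any(a, b, operator.ge))   # True (2 >= 2 holds)
```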

Get something like this working in Python, and then try it in Cython. If you have trouble with that step, come back with a new question. SO has a good base of Cython users.

answered Sep 21 '22 12:09 by hpaulj



How large are your arrays? I would imagine they need to be very large, e.g. A.shape = (1000000,) or larger, before performance becomes an issue. Would you consider using numpy views?
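Basic slicing returns a view rather than a copy, which is what keeps the segment approach cheap. A quick check with np.shares_memory (the array size is arbitrary):

```python
import numpy as np

A = np.arange(1_000_000)
seg = A[:10]                         # basic slice: a view, no data copied
print(np.shares_memory(seg, A))      # True
print(seg.base is A)                 # True

# Comparing two views still allocates a result, but only 10 booleans.
B = A + 1
print((A[:10] < B[:10]).nbytes)      # 10
```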

Instead of comparing (A < B).all() or (A < B).any(), you can try comparing views (slices), such as (A[:10] < B[:10]).all(). Here's a simple loop that might work:

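n = 10  # segment size
k = 0
# stop at the first segment containing a failing pair, or once past the end
while k*n < len(A) and (A[k*n:(k+1)*n] < B[k*n:(k+1)*n]).all():
    k += 1
all_less = k*n >= len(A)  # True only if every segment passed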

Instead of 10 you can use 100, 10**3, or whatever segment size you wish. With a segment size of 1, you are simply saying:

k = 0
while k < len(A) and A[k] < B[k]:  # bounds check avoids an IndexError
    k += 1
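Both loops above can be folded into one function. This sketch assumes A and B are 1-D and of equal length; chunked_all_less and its chunk parameter are hypothetical names, not a numpy API. Each step allocates only a chunk-sized temporary, and the loop stops at the first chunk that contains a failing pair.

```python
import numpy as np

def chunked_all_less(A, B, chunk=1000):
    """True if A[i] < B[i] for all i, checked chunk by chunk (1-D, equal length assumed)."""
    n = len(A)
    for start in range(0, n, chunk):
        stop = start + chunk          # slicing past the end just truncates
        if not (A[start:stop] < B[start:stop]).all():
            return False              # shortcut: later chunks are never compared
    return True

A = np.arange(10_000)
B = A + 1
print(chunked_all_less(A, B))              # True
B[5] = 0
print(chunked_all_less(A, B, chunk=100))   # False, decided in the first chunk
```

The chunk size trades Python loop overhead (small chunks) against temporary size and wasted work past the first failure (large chunks).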

Sometimes comparing the entire array at once can become memory intensive. If A and B each have length 10000 and I need to compare every pair of elements (each A[i] against each B[j]), the temporary has 10000 * 10000 entries, and I may run out of memory.
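Concretely, numpy stores each boolean in one byte, so a 10000 x 10000 temporary takes 10**8 bytes, about 100 MB. A minimal sketch of a row-at-a-time alternative, assuming the test is "every A[i] is less than every B[j]": each step allocates only len(B) booleans and stops at the first failing row. The name all_pairs_less is hypothetical.

```python
import numpy as np

def all_pairs_less(A, B):
    """True if A[i] < B[j] for every i and j, one row of len(B) booleans at a time."""
    for x in A:
        if not (x < B).all():   # temporary of size len(B), not len(A) * len(B)
            return False
    return True

n = 10_000
print(n * n * np.dtype(bool).itemsize)   # 100000000 bytes for the full (n, n) temporary

A = np.arange(n)
B = np.arange(n, 2 * n)
print(all_pairs_less(A, B))               # True (max(A) = 9999 < min(B) = 10000)
```

For this particular test, A.max() < B.min() gives the same answer for non-empty arrays without NaNs; the row loop matters more when the pairwise condition does not reduce that neatly.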

answered Sep 20 '22 12:09 by john mangual
