Get (column, row) index from NumPy array that meets a boolean condition

I am working with a 2D NumPy array. I would like to get the (column, row) indexes, or (x, y) coordinates if you prefer thinking that way, of the elements in my 2D array that meet a boolean condition.

The best way I can explain what I am trying to do is via a trivial example:

>>> import numpy as np
>>> a = np.arange(9).reshape(3, 3)
>>> b = a > 4
>>> b
array([[False, False, False],
       [False, False,  True],
       [ True,  True,  True]], dtype=bool)

I now have a boolean array indicating where a > 4.

My goal is to grab the indexes of the boolean array where the value is True. For example, the indexes (1, 2), (2, 0), (2, 1), and (2, 2) all have a value of True.

In the end, I want a list of indexes:

>>> indexes = [(1, 2), (2, 0), (2, 1), (2, 2)]

Again, the code above is a trivial example. In my real application, the positions where a > 4 could be arbitrary, not something produced by arange and reshape.

Adrian Rosebrock asked Jan 28 '14 14:01



2 Answers

Use numpy.where with numpy.column_stack:

>>> np.column_stack(np.where(b))
array([[1, 2],
       [2, 0],
       [2, 1],
       [2, 2]])
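
If a plain Python list of tuples is wanted, as in the question, the row and column arrays returned by np.where can be zipped together. This is a minimal sketch using the question's toy array. Note that the pairs come out in (row, column) order, which is the order the question's expected output uses.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = a > 4

# np.where on a 2D boolean array returns one index array per axis.
rows, cols = np.where(b)

# .tolist() converts NumPy integers to plain Python ints before zipping.
indexes = list(zip(rows.tolist(), cols.tolist()))
print(indexes)  # [(1, 2), (2, 0), (2, 1), (2, 2)]
```

For large arrays, building a Python list is likely to be much slower than keeping the result as a NumPy array, so this is mainly useful when a list is really required.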
Ashwini Chaudhary answered Oct 21 '22 22:10



An alternative to @Ashwini Chaudhary's answer is numpy.nonzero:

>>> a = np.arange(9).reshape(3,3)
>>> b = a > 4
>>> np.nonzero(b)
(array([1, 2, 2, 2]), array([2, 0, 1, 2]))

>>> np.transpose(np.nonzero(b))
array([[1, 2],
       [2, 0],
       [2, 1],
       [2, 2]])
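
NumPy also provides np.argwhere, which returns the same (row, column) pairs in a single call. The sketch below uses the same toy array. Reversing the columns to get (column, row), or (x, y), order is an extra step added here for illustration, since the question's title asks for that order.

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = a > 4

# One row per True element, as (row, column).
row_col = np.argwhere(b)

# Reverse the last axis to get (column, row), i.e. (x, y).
col_row = row_col[:, ::-1]

print(row_col.tolist())  # [[1, 2], [2, 0], [2, 1], [2, 2]]
print(col_row.tolist())  # [[2, 1], [0, 2], [1, 2], [2, 2]]
```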

EDIT: Which is faster? nonzero and where are essentially equivalent, but transpose turns out to be the slow choice here (even though it's mentioned in the docs):

In [15]: N = 5000

In [16]: a = np.random.random((N, N))

In [17]: %timeit np.nonzero(a > 0.5)
1 loops, best of 3: 470 ms per loop

In [18]: %timeit np.transpose(np.nonzero(a > 0.5))     # ooops
1 loops, best of 3: 2.56 s per loop

In [19]: %timeit np.where(a > 0.5)
1 loops, best of 3: 467 ms per loop

In [20]: %timeit np.column_stack(np.where(a > 0.5))
1 loops, best of 3: 653 ms per loop
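
The timings above come from one machine and an older NumPy, so the ratios may differ elsewhere. Outside IPython, the comparison can be re-run with the standard timeit module, as in the rough sketch below. The smaller array size, the fixed seed, and the inclusion of np.argwhere are choices made here for illustration and are not part of the original benchmark. Unlike the session above, the mask is also computed once up front, so only the index extraction is timed.

```python
import timeit

import numpy as np

# Smaller than the N = 5000 used above so the script finishes quickly.
rng = np.random.default_rng(0)
a = rng.random((500, 500))
mask = a > 0.5

candidates = {
    "nonzero": lambda: np.nonzero(mask),
    "transpose(nonzero)": lambda: np.transpose(np.nonzero(mask)),
    "column_stack(where)": lambda: np.column_stack(np.where(mask)),
    "argwhere": lambda: np.argwhere(mask),
}

# Best of 3 repeats, 5 calls each, reported as seconds per call.
timings = {
    name: min(timeit.repeat(fn, number=5, repeat=3)) / 5
    for name, fn in candidates.items()
}

for name, seconds in timings.items():
    print(f"{name:22s} {seconds * 1e3:8.3f} ms per call")
```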
ev-br answered Oct 21 '22 21:10
