
Understanding numpy.where

Tags:

python

numpy

I am reading the numpy.where(condition[, x, y]) documentation, but I cannot understand this small example:

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
Out: (array([2, 2, 2]), array([0, 1, 2]))

Can someone explain how this result is produced?

xirururu asked Feb 11 '23 09:02


2 Answers

The first array (array([2, 2, 2])) holds the row indices and the second (array([0, 1, 2])) holds the column indices of the values that are greater than 5.
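To make the row/column split easier to see, here is a minimal sketch that looks at the intermediate boolean mask before calling np.where. The names mask, rows and cols are only illustrative and do not appear in the question.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# Boolean array with the same shape as x, True where the condition holds
mask = x > 5

# With a single argument, np.where returns one index array per dimension
rows, cols = np.where(mask)

print(mask)
print(rows)  # row index of each True entry
print(cols)  # column index of each True entry
```

Reading rows and cols position by position gives the coordinates of each match: (2, 0), (2, 1) and (2, 2).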

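You can use zip to pair them up into (row, column) index tuples. On Python 3, zip returns an iterator, so wrap it in list() to see the result (NumPy 2.x prints each entry as np.int64(...) rather than a bare integer):

>>> list(zip(*np.where(x > 5)))
[(2, 0), (2, 1), (2, 2)]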

Or use np.dstack, which returns the same pairs as an array of shape (1, 3, 2):

>>> np.dstack(np.where( x > 5 ))
array([[[2, 0],
        [2, 1],
        [2, 2]]])
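If a plain two-column array of coordinates is all you need, np.argwhere may be a simpler option. This is an alternative that the answer above does not mention, shown here only as a sketch. The variable name pairs is illustrative.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# np.argwhere returns an array of shape (number_of_matches, x.ndim),
# with one [row, column] pair per element that satisfies the condition
pairs = np.argwhere(x > 5)

print(pairs.shape)     # (3, 2)
print(pairs.tolist())  # [[2, 0], [2, 1], [2, 2]]
```

Unlike the np.dstack result, there is no extra leading axis to strip off.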
Mazdak answered Feb 13 '23 02:02


It's printing the coordinates of the elements that satisfy your condition:

import numpy as np

x = np.arange(9.).reshape(3, 3)
print(x)
print(np.where(x > 5))

where print(x) prints:

[[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]

and the second print shows the index locations (row indices first, then column indices) of all elements greater than 5:

(array([2, 2, 2]), array([0, 1, 2]))

where x[2, 0] == 6, x[2, 1] == 7 and x[2, 2] == 8
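As a quick check of that mapping, the tuple returned by np.where can be passed straight back to x as an index. The sketch below also shows the three-argument form from the numpy.where(condition[, x, y]) signature quoted in the question. The names idx, selected and replaced are illustrative, and the replacement value 0.0 is an arbitrary choice for this example.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# One-argument form: a tuple of index arrays, usable directly as an index
idx = np.where(x > 5)
selected = x[idx]
print(selected.tolist())  # [6.0, 7.0, 8.0]

# Three-argument form: take values from x where the condition is True,
# and from the third argument (here the scalar 0.0) where it is False
replaced = np.where(x > 5, x, 0.0)
print(replaced.tolist())
```

The one-argument form answers "where are the matches?", while the three-argument form builds a new array of the same shape as x.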

reticentroot answered Feb 13 '23 02:02