
How does python numpy.where() work?

I am playing with numpy and digging through documentation and I have come across some magic. Namely I am talking about numpy.where():

    >>> import numpy as np
    >>> x = np.arange(9.).reshape(3, 3)
    >>> np.where(x > 5)
    (array([2, 2, 2]), array([0, 1, 2]))

How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__ but I am looking for a detailed explanation.

asked Apr 12 '11 22:04 by pajton




1 Answer

How do they achieve internally that you are able to pass something like x > 5 into a method?

The short answer is that they don't.
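One way to convince yourself of this is to split the call into two steps. The sketch below is only an illustration (the names mask, same_mask, rows and cols are made up here): Python evaluates x > 5 first, and np.where only receives the finished result.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# Python evaluates the argument expression before the call happens,
# so `mask` is an ordinary array by the time np.where sees it.
mask = x > 5

# The operator is just sugar for the comparison method on the array.
same_mask = x.__gt__(5)

# np.where receives the boolean array, never the expression `x > 5`.
rows, cols = np.where(mask)

print(type(mask).__name__, mask.dtype)  # ndarray bool
print(rows, cols)                       # [2 2 2] [0 1 2]
```

Writing np.where(x > 5) in one line does exactly the same thing; the intermediate array just never gets a name.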

Any comparison on a numpy array returns a boolean array. That is, __gt__, __lt__, etc. are applied elementwise and return an array of the same shape that is True wherever the condition holds.
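Since the question guesses at __gt__, here is a rough sketch of the operator hook in plain Python. MiniArray and count_true are hypothetical names invented for this illustration, and this is not how numpy is implemented; it only shows that `>` can return something other than a single bool.

```python
class MiniArray:
    """Hypothetical toy container; not how numpy is implemented."""

    def __init__(self, values):
        self.values = list(values)

    def __gt__(self, other):
        # Python calls this for `self > other`. The method may return any
        # object, so we return one bool per element instead of a single bool.
        return [v > other for v in self.values]


def count_true(flags):
    # Receives a plain list of bools; it never sees the `>` expression.
    return sum(flags)


m = MiniArray([3, 6, 9])
print(m > 5)              # [False, True, True]
print(count_true(m > 5))  # 2
```

numpy arrays plug into the same operator protocol, except that the elementwise comparison runs in compiled code and the result is another array.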

E.g.

    import numpy as np

    x = np.arange(9).reshape(3, 3)
    print(x > 5)

yields:

    [[False False False]
     [False False False]
     [ True  True  True]]

This is also why something like if x > 5: raises a ValueError when x is a numpy array with more than one element. The condition is an array of True/False values, not a single value.
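A small sketch of that failure and the usual way around it, assuming you want a single yes/no answer from the array (the variable names are illustrative):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

raised = False
try:
    if x > 5:  # asks for the truth value of a 3x3 boolean array
        pass
except ValueError as exc:
    raised = True
    print("ValueError:", exc)

# Reduce the boolean array to one value explicitly instead.
any_above = bool((x > 5).any())  # True: at least one element is > 5
all_above = bool((x > 5).all())  # False: not every element is > 5
print(any_above, all_above)      # True False
```

Which reduction is right depends on what the if statement was meant to ask, which is why numpy refuses to guess.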

Furthermore, numpy arrays can be indexed by boolean arrays. E.g. x[x>5] yields [6 7 8], in this case.
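For illustration, a short sketch of boolean indexing on the same array, both for selecting and for assigning (mask, selected and y are names chosen for this example):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5

# Selecting with a boolean array returns the matching elements as a 1-D array.
selected = x[mask]
print(selected)  # [6 7 8]

# A boolean array also works on the left-hand side of an assignment.
y = x.copy()
y[mask] = 0
print(y.tolist())  # [[0, 1, 2], [3, 4, 5], [0, 0, 0]]
```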

Honestly, it's fairly rare that you actually need numpy.where. Called with a single argument, it just returns the indices where a boolean array is True, as a tuple with one index array per dimension. Usually you can do what you need with simple boolean indexing.
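To make the comparison concrete, here is a sketch showing that single-argument np.where gives the same indices as np.nonzero, and that indexing with those indices selects the same elements as indexing with the boolean array directly (variable names are illustrative):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
mask = x > 5

# One index array per dimension: (row indices, column indices).
idx = np.where(mask)
same_as_nonzero = all(
    np.array_equal(a, b) for a, b in zip(idx, np.nonzero(mask))
)

# Both forms of indexing pick out the same elements.
via_where = x[idx]
via_mask = x[mask]

# The indices themselves are what np.where adds, e.g. as (row, col) pairs.
coords = [(int(r), int(c)) for r, c in zip(*idx)]

print(same_as_nonzero)                        # True
print(via_where.tolist(), via_mask.tolist())  # [6, 7, 8] [6, 7, 8]
print(coords)                                 # [(2, 0), (2, 1), (2, 2)]
```

So np.where earns its keep mainly when you need the positions themselves rather than the values.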

answered Oct 04 '22 15:10 by Joe Kington