
Where can I find numpy.where() source code? [duplicate]

Tags: python, numpy

I have already found the source for the numpy.ma.where() function, but it seems to call numpy.where(). To better understand it, I would like to look at the source of numpy.where() as well, if possible.

usr48 asked Feb 03 '19 02:02



2 Answers

Most Python functions are written in the Python language, but some functions are written in something more native (often the C language).

Regular Python functions ("pure Python")

There are a few techniques you can use to ask Python itself where a function is defined. Probably the most portable uses the inspect module:

>>> import numpy
>>> import inspect
>>> inspect.isbuiltin(numpy.ma.where)
False
>>> inspect.getsourcefile(numpy.ma.where)
'.../numpy/ma/core.py'

But this won't work with a native ("built-in") function:

>>> import numpy
>>> import inspect
>>> inspect.isbuiltin(numpy.where)
True
>>> inspect.getsourcefile(numpy.where)
TypeError: <built-in function where> is not a module, class, method, function, traceback, frame, or code object

Native ("built-in") functions

Unfortunately, Python doesn't provide a record of source files for built-in functions. You can find out which module provides the function:

>>> import numpy as np
>>> np.where
<built-in function where>
>>> np.where.__module__
'numpy.core.multiarray'
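
The two checks above can be folded into one helper. This is only a minimal sketch: the name describe_origin is hypothetical (it is not part of NumPy or the standard library), and it is demonstrated on standard-library functions so that the result does not depend on the installed NumPy version.

```python
import inspect
import json


def describe_origin(func):
    """Return (kind, location) for a callable.

    kind is "python" when inspect can find a source file, else "native".
    location is the source file path, or the providing module name as a fallback.
    """
    try:
        path = inspect.getsourcefile(func)
    except TypeError:
        # Raised for built-in (native) functions, which carry no source record.
        path = None
    if path is not None:
        return ("python", path)
    return ("native", getattr(func, "__module__", None))


print(describe_origin(json.dumps))  # pure Python: a path ending in json/__init__.py
print(describe_origin(len))         # native: falls back to the module name, 'builtins'
```

Passing numpy.where to the same helper should give either a source path or the providing module name, which is the starting point for searching the C sources. The exact result may differ between NumPy releases.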

Python won't help you find the native (C) source code for that module, but in this case it's reasonable to look in the numpy project for C source that has similar names. I found the following file:

numpy/core/src/multiarray/multiarraymodule.c

And in that file, I found a list of definitions (PyMethodDef) including:

    {"where",
        (PyCFunction)array_where,
        METH_VARARGS, NULL},

This suggests that the C function array_where is the one that Python sees as "where".

The array_where function is defined in the same file, and it mostly delegates to the PyArray_Where function.

In short

NumPy's np.where function is written in C, not Python. A good place to look is PyArray_Where.

RJHunter answered Oct 12 '22 13:10


First, there are two distinct versions of where: one that takes just the condition, and one that takes three arrays.

The simpler one is the most commonly used, and is just another name for np.nonzero. It scans through the condition array twice: first with np.count_nonzero to determine how many nonzero entries there are, which allows it to allocate the return arrays, and then a second time to fill in the coordinates of all nonzero entries. The key point is that it returns a tuple of arrays, one array for each dimension of condition.
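
The condition-only form can be checked against np.nonzero directly. The array below is made up for illustration; the sketch shows observable behavior only, not the C implementation.

```python
import numpy as np

cond = np.array([[True, False, True],
                 [False, False, True]])

# One index array per dimension of cond.
rows, cols = np.where(cond)
print(rows)  # [0 0 1]
print(cols)  # [0 2 2]

# Same result as np.nonzero, array for array.
nz = np.nonzero(cond)
assert all(np.array_equal(a, b) for a, b in zip((rows, cols), nz))

# The number of hits matches what np.count_nonzero reports.
assert rows.size == np.count_nonzero(cond)

# The tuple can be used directly as an index into the original array,
# and every selected entry is True.
assert cond[rows, cols].all()
```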

The condition, x, y version takes three arrays, which it broadcasts against each other. The return array has the common broadcasted shape, with elements chosen from x and y as explained in the answers to your previous question, How exactly does numpy.where() select the elements in this example?
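
A small test case for the three-argument form, with made-up values. The list comprehension is only a plain-Python model of the selection rule, not how NumPy implements it.

```python
import numpy as np

cond = np.array([True, False, True, False])
x = np.array([1, 2, 3, 4])
y = np.array([10, 20, 30, 40])

# Take from x where cond is True, from y elsewhere.
picked = np.where(cond, x, y)
print(picked)  # [ 1 20  3 40]

# Plain-Python model of the same element-by-element choice.
expected = [xv if c else yv
            for c, xv, yv in zip(cond.tolist(), x.tolist(), y.tolist())]
assert picked.tolist() == expected

# Scalars are accepted for x or y as well.
kept = np.where(x > 2, x, 0)
print(kept)  # [0 0 3 4]
```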

Keep in mind that most of this code is C or Cython, with a significant amount of preprocessing. It is hard to read, even for experienced users. It is easier to run a variety of test cases and get a feel for what is happening that way.

A couple of things to watch out for. np.where is called like any other Python function, so Python evaluates each argument fully before passing it in. This is conditional assignment, not conditional evaluation.
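
The eager evaluation is easy to see with a division. In this sketch (values made up for illustration), the reciprocal is computed for every element, including the zero, before np.where ever sees it. The second half shows one possible way to skip the unwanted computation, using the ufunc's own where and out arguments instead of np.where.

```python
import numpy as np

x = np.array([4.0, 0.0, 0.5])

# Both branch values exist in full before np.where runs, so 1.0 / x
# still divides by zero at index 1 (the warning is silenced here).
with np.errstate(divide="ignore"):
    reciprocal_all = 1.0 / x
    eager = np.where(x != 0, reciprocal_all, 0.0)

assert np.isinf(reciprocal_all[1])  # the unwanted value was computed anyway

# Alternative: let np.divide skip the masked-out elements entirely.
# Positions where the mask is False keep the value already in `out`.
safe = np.divide(1.0, x, out=np.zeros_like(x), where=x != 0)

assert np.array_equal(eager, safe)
print(safe)
```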

And unless you pass 3 arrays that match in shape, or scalar x and y, you'll need a good understanding of broadcasting.
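
A short test case for the broadcasting behavior, with shapes chosen purely for illustration: a column-shaped condition, a one-dimensional x, and a scalar y.

```python
import numpy as np

cond = np.array([[True], [False], [True]])  # shape (3, 1)
x = np.array([1, 2, 3, 4])                  # shape (4,)
y = -1                                      # scalar

result = np.where(cond, x, y)
print(result.shape)  # (3, 4)
print(result)

# The output shape is the common broadcast shape of all three inputs.
assert result.shape == np.broadcast(cond, x, y).shape
```

Rows where the condition is True are copies of x, and the remaining row is filled with y.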

hpaulj answered Oct 12 '22 11:10