MATLAB-style find() function in Python

In MATLAB it is easy to find the indices of values that meet a particular condition:

>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2)     % find the indices where this condition is true
[3, 6, 9]          % (MATLAB uses 1-based indexing)
>> a(find(a > 2))  % get the values at those locations
[3, 3, 3]

What would be the best way to do this in Python?

So far, I have come up with the following. To just get the values:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]

But if I want the index of each of those values it's a bit more complicated:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]

Is there a better way to do this in Python, especially for arbitrary conditions (not just 'val > 2')?

I found functions equivalent to MATLAB 'find' in NumPy but I currently do not have access to those libraries.

user344226 asked May 10 '11 23:05


8 Answers

In NumPy you have where:

>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13,  1, 15,  8,  0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
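With a single argument, np.where returns a tuple holding one index array per dimension, which is why the output above is wrapped in parentheses. A minimal sketch, using a fixed array chosen here for illustration (not the random one from the answer), of pulling out the plain index array and the matching values:

```python
import numpy as np

# Fixed input so the result is reproducible (illustrative values).
a = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])

# np.where(condition) returns a tuple; for a 1-D array, element 0 holds the indices.
inds = np.where(a > 2)[0]
vals = a[inds]

print(inds)  # [2 5 8]
print(vals)  # [3 3 3]
```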
joaquin answered Oct 03 '22 07:10


You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)

>>> inds
[2, 5, 8]

It's a little closer to your MATLAB example, without having to load all of NumPy.
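Building on that idea, here is a small pure-Python sketch that mirrors both MATLAB steps, find(...) and a(find(...)). The helper names find and take are made up for this illustration and are not part of any library:

```python
def find(seq, predicate):
    """Return the 0-based indices of the items for which predicate(item) is true."""
    return [i for i, val in enumerate(seq) if predicate(val)]


def take(seq, inds):
    """Return the items of seq at the given indices."""
    return [seq[i] for i in inds]


a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = find(a, lambda x: x > 2)
print(inds)           # [2, 5, 8]
print(take(a, inds))  # [3, 3, 3]

# Any callable works as the condition, e.g. "value is even".
print(find(a, lambda x: x % 2 == 0))  # [1, 4, 7]
```

Indexing with seq[i] looks each value up directly, which avoids scanning the index list once per element the way an `i in inds` test does.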

John answered Oct 03 '22 06:10


Or use numpy's nonzero function:

>>> import numpy as np
>>> a = np.array([1,2,3,4,5])
>>> inds = np.nonzero(a>2)
>>> a[inds]
array([3, 4, 5])
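For arrays with more than one dimension, np.nonzero returns one index array per dimension. A short sketch with a made-up 2-D array shows how the pieces fit together:

```python
import numpy as np

m = np.array([[1, 5],
              [7, 2]])

# One index array per dimension: rows first, then columns.
rows, cols = np.nonzero(m > 2)
print(rows, cols)     # [0 1] [1 0]
print(m[rows, cols])  # [5 7]
```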
vincentv answered Oct 03 '22 07:10


Why not just use this:

[i for i in range(len(a)) if a[i] > 2]

or for arbitrary conditions, define a function f for your condition and do:

[i for i in range(len(a)) if f(a[i])]
JasonFruit answered Oct 03 '22 06:10


The NumPy routine more commonly used for this is numpy.where(). Called with only a condition, it returns the same result as numpy.nonzero().

import numpy
a    = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)

To get the values, you can either store the indices and index with them:

a[inds]

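or index with the boolean mask directly:

a[a>2]

numpy.where() also accepts two further arrays, which must be passed together (numpy.where(a>2, a) on its own raises a ValueError). It then takes elements from the first array where the condition holds and from the second elsewhere: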

b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)
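As a quick check of what the three-argument form returns, here is the same call run on the arrays above, with the result written out:

```python
import numpy

a = numpy.array([1, 2, 3, 4, 5])
b = numpy.array([11, 22, 33, 44, 55])

# Elementwise choice: a[i] where a[i] > 2, otherwise b[i].
picked = numpy.where(a > 2, a, b)
print(picked)  # [11 22  3  4  5]
```

Note that this form returns values, not indices, so it is a selection tool and not a replacement for find.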
ryanjdillon answered Oct 03 '22 07:10


To get values with arbitrary conditions, you could use filter() with a lambda function (in Python 3, filter() returns an iterator, so wrap it in list()):

>>> a = [1,2,3,1,2,3,1,2,3]
>>> list(filter(lambda x: x > 2, a))
[3, 3, 3]

One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print(aind)
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> tuple(filter(lambda x: x[1] > 2, aind))
((2, 3), (5, 3), (8, 3))
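The filtered (index, value) pairs can be split back into separate index and value sequences. A sketch of one way to do that in Python 3, with variable names chosen here for illustration:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# Keep the (index, value) pairs whose value passes the condition.
pairs = list(filter(lambda p: p[1] > 2, enumerate(a)))

# zip(*pairs) transposes the pairs; guard against an empty result,
# where there would be nothing to unpack.
inds, vals = zip(*pairs) if pairs else ((), ())

print(inds)  # (2, 5, 8)
print(vals)  # (3, 3, 3)
```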
Blair answered Oct 03 '22 07:10


I've been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):

import numpy

a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]

It turns out that this is much faster than:

inds = [i for (i, val) in enumerate(a) if val > 2]

It seems that the comparison is faster when NumPy performs it over the whole array at once, and/or that the list comprehension is faster when it only checks truth values instead of doing the comparison itself.
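Timings like this depend heavily on the input size, the Python and NumPy versions, and the machine, so it is worth measuring on your own data. A rough sketch using timeit follows; the list contents, the function names, and the run count are arbitrary choices for this illustration:

```python
import timeit

import numpy as np

a = list(range(10)) * 1000  # 10,000 illustrative values


def via_numpy_mask():
    # Vectorized comparison first, then a truth-only list comprehension.
    a_bool = np.array(a) > 2
    return [i for i, val in enumerate(a_bool) if val]


def via_plain_list():
    # Comparison done inside the list comprehension.
    return [i for i, val in enumerate(a) if val > 2]


# Both approaches must agree before timing means anything.
assert via_numpy_mask() == via_plain_list()

for fn in (via_numpy_mask, via_plain_list):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds:.4f} s for 20 runs")
```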

Edit:

I was revisiting my code and came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line:

inds = numpy.arange(len(a))[numpy.asarray(a) > 2]
Nate answered Oct 03 '22 06:10


I think I may have found a quick and simple substitute. BTW, I find np.where() not very satisfactory, in the sense that its result somehow includes an annoying row of zeros.

import numpy as np
import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print(a)

>> [[ 1.36406736  1.45217257 -0.06896245  0.98429727 -0.59281957]]

idx = mlab.find(a<0)
print(idx)
type(idx)

>> [2 4]
>> np.ndarray
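As far as I know, matplotlib.mlab.find is no longer available in recent Matplotlib releases. np.flatnonzero gives the same kind of flat index array using NumPy alone. A sketch with the values above rounded for illustration:

```python
import numpy as np

# Shape (1, 5), values rounded from the example above for illustration.
a = np.array([[1.36, 1.45, -0.07, 0.98, -0.59]])

# Indices into the flattened array where the condition is true.
idx = np.flatnonzero(a < 0)
print(idx)        # [2 4]
print(type(idx))  # <class 'numpy.ndarray'>
```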

Best, Da

DidasW answered Oct 03 '22 07:10