In MATLAB it is easy to find the indices of values that meet a particular condition:
>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2) % find the indices where this condition is true
[3, 6, 9] % (MATLAB uses 1-based indexing)
>> a(find(a > 2)) % get the values at those locations
[3, 3, 3]
What would be the best way to do this in Python?
So far, I have come up with the following. To just get the values:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]
But if I want the index of each of those values it's a bit more complicated:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]
Is there a better way to do this in Python, especially for arbitrary conditions (not just 'val > 2')?
I found functions equivalent to MATLAB 'find' in NumPy but I currently do not have access to those libraries.
In NumPy you have where:
>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
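Note that the result above is wrapped in a tuple: called with only a condition, np.where() returns one index array per dimension. A minimal sketch of unpacking it for a 1-D array, with the values from the output above hard-coded (an assumption made only so the result is reproducible):

```python
import numpy as np

# Same values as the random draw shown above, hard-coded for reproducibility.
x = np.array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])

# With a single argument, np.where returns a tuple of index arrays,
# one per dimension; a 1-D input gives a 1-tuple.
(inds,) = np.where(x > 10)

print(inds.tolist())     # [0, 1, 3, 6, 7, 8, 9]
print(x[inds].tolist())  # [14, 13, 15, 17, 11, 19, 13]
```

Indexing x with inds gives the matching values, much like a(find(a > 2)) in MATLAB.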
You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:
def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)
>>> inds
[2, 5, 8]
It's a little closer to your MATLAB example, without having to load all of NumPy.
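Once you have the indices, the values can be looked up directly by position rather than by scanning the list again. A minimal sketch in plain Python, reusing the indices helper from this answer (the name vals is just for illustration):

```python
def indices(a, func):
    """Return the positions in `a` whose values satisfy `func`."""
    return [i for i, val in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)

# Look the values up by position instead of re-scanning the list.
vals = [a[i] for i in inds]

print(inds)  # [2, 5, 8]
print(vals)  # [3, 3, 3]
```

This avoids the "if i in inds" membership test from the question, which searches inds once for every element of a.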
Or use numpy's nonzero function:
import numpy as np
a = np.array([1,2,3,4,5])
inds = np.nonzero(a>2)
a[inds]
array([3, 4, 5])
Why not just use this:
[i for i in range(len(a)) if a[i] > 2]
Or, for arbitrary conditions, define a function f for your condition and do:
[i for i in range(len(a)) if f(a[i])]
The NumPy routine more commonly used for this is numpy.where(); when called with only a condition, it behaves the same as numpy.nonzero().
import numpy
a = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)
To get the values, you can either store the indices and index with them:
a[inds]
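or you can pass two arrays along with the condition, in which case numpy.where() takes each element from the first array where the condition is true and from the second array elsewhere (passing the condition and only one array raises a ValueError):
b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)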
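A small sketch contrasting the two calling forms, on the assumption that numpy.where() is given either the condition alone or the condition plus two arrays (the name picked is just for illustration):

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = np.array([11, 22, 33, 44, 55])

# One argument: a tuple of index arrays (same result as np.nonzero).
(inds,) = np.where(a > 2)
print(inds.tolist())     # [2, 3, 4]
print(a[inds].tolist())  # [3, 4, 5]

# Three arguments: element-wise choice between a and b.
picked = np.where(a > 2, a, b)
print(picked.tolist())   # [11, 22, 3, 4, 5]
```

The one-argument form answers "where is the condition true?", while the three-argument form builds a new array and does not return indices at all.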
To get values with arbitrary conditions, you could use filter() with a lambda function (in Python 3, filter() returns an iterator, so wrap it in list()):
>>> a = [1,2,3,1,2,3,1,2,3]
>>> list(filter(lambda x: x > 2, a))
[3, 3, 3]
One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print(aind)
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> tuple(filter(lambda x: x[1] > 2, aind))
((2, 3), (5, 3), (8, 3))
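The filtered pairs above still carry the values along with the indices. A minimal sketch, assuming Python 3 (where filter() returns a lazy iterator), of reducing them to bare indices; the names pairs and inds are just for illustration:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# filter() is lazy in Python 3, so wrap it in list() to see the values.
vals = list(filter(lambda x: x > 2, a))

# Filter (index, value) pairs, then keep only the index of each pair.
pairs = list(filter(lambda pair: pair[1] > 2, enumerate(a)))
inds = [i for i, _ in pairs]

print(vals)   # [3, 3, 3]
print(pairs)  # [(2, 3), (5, 3), (8, 3)]
print(inds)   # [2, 5, 8]
```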
I've been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):
import numpy
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]
It turns out that this is much faster than:
inds = [i for (i, val) in enumerate(a) if val > 2]
It seems that Python is faster at comparison when done in a numpy array, and/or faster at doing list comprehensions when just checking truth rather than comparison.
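If you want to check the speed difference on your own machine, here is a rough sketch using timeit. The test data and the function names are assumptions made for illustration, and the timings will vary with list size and hardware, so no expected numbers are shown:

```python
import timeit
import numpy as np

a = list(range(10)) * 1000  # hypothetical test data, 10,000 elements

def via_numpy_mask():
    # Vectorized comparison first, then a truth test per element.
    a_bool = np.array(a) > 2
    return [i for i, val in enumerate(a_bool) if val]

def via_plain_list():
    # Comparison done element by element in pure Python.
    return [i for i, val in enumerate(a) if val > 2]

# Both approaches should agree before their timings are compared.
assert via_numpy_mask() == via_plain_list()

for fn in (via_numpy_mask, via_plain_list):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds:.4f} s for 20 runs")
```

Note that the first variant also pays for converting the list to an array on every call, which matters for small inputs.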
Edit:
I was revisiting my code and came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line:
inds = numpy.arange(len(a))[numpy.asarray(a) > 2]
I think I may have found a quick and simple substitute. I found np.where() unsatisfactory here because, for a 2-D input like the one below, it also returns an array of row indices that are all zero.
import numpy as np
import matplotlib.mlab as mlab  # mlab.find was removed in later Matplotlib releases
a = np.random.randn(1,5)
print(a)
>> [[ 1.36406736 1.45217257 -0.06896245 0.98429727 -0.59281957]]
idx = mlab.find(a<0)
print(idx)
print(type(idx))
>> [2 4]
>> <class 'numpy.ndarray'>
Best, Da
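As far as I know, mlab.find is no longer available in recent Matplotlib releases. A minimal sketch of getting the same flat indices with NumPy alone, using np.flatnonzero() as a swapped-in substitute and the values from the printed output above hard-coded for reproducibility:

```python
import numpy as np

# Fixed values (copied from the printed output above) instead of random ones.
a = np.array([[1.36406736, 1.45217257, -0.06896245, 0.98429727, -0.59281957]])

# np.where on a 2-D array returns one index array per dimension.
rows, cols = np.where(a < 0)
print(rows.tolist(), cols.tolist())  # [0, 0] [2, 4]

# np.flatnonzero returns indices into the flattened array instead.
idx = np.flatnonzero(a < 0)
print(idx.tolist())  # [2, 4]
```

For a 1-row array the flat indices coincide with the column indices, which is why both give [2, 4] here.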