Find index positions where 3D-array meets MULTIPLE conditions

I have a 3D array with several bands (along the last axis), each holding numbers. Is there a function that returns the index positions where the array meets MULTIPLE conditions?

I tried the following:

index_pos = numpy.where(
    array[:,:,0]==10 and array[:,:,1]==15 and array[:,:,2]==30)

It returns the error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
asked Nov 04 '13 at 14:11 by MoTSCHIGGE


1 Answer

You actually have a special case where broadcasting gives a simpler solution (although the timings further down show it is not necessarily faster).

Given this example data:

>>> arr
array([[[ 6,  9,  4],
        [ 5,  2,  1],
        [10, 15, 30]],

       [[ 9,  0,  1],
        [ 4,  6,  4],
        [ 8,  3,  9]],

       [[ 6,  7,  4],
        [ 0,  1,  6],
        [ 4,  0,  1]]])
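The printout above does not show how arr was built. A minimal way to reproduce it is a literal np.array with the values copied from the printout; the last axis holds the three bands.

```python
import numpy as np

# Values copied from the printout: shape (3, 3, 3), last axis = bands
arr = np.array([[[ 6,  9,  4],
                 [ 5,  2,  1],
                 [10, 15, 30]],

                [[ 9,  0,  1],
                 [ 4,  6,  4],
                 [ 8,  3,  9]],

                [[ 6,  7,  4],
                 [ 0,  1,  6],
                 [ 4,  0,  1]]])

print(arr.shape)  # (3, 3, 3)
```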

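The expected result. Python's and tries to reduce each boolean array to a single True/False, which is what raises the ValueError; use the element-wise & operator instead, and wrap each comparison in parentheses because & binds more tightly than ==: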

>>> index_pos = np.where((arr[:,:,0]==10) & (arr[:,:,1]==15) & (arr[:,:,2]==30))
>>> index_pos
(array([0]), array([2]))
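A minimal sketch of why the original attempt fails and the & version works, reusing the same example values (the names and_failed, mask, rows and cols are only for illustration):

```python
import numpy as np

arr = np.array([[[6, 9, 4], [5, 2, 1], [10, 15, 30]],
                [[9, 0, 1], [4, 6, 4], [8, 3, 9]],
                [[6, 7, 4], [0, 1, 6], [4, 0, 1]]])

# `and` asks Python for one truth value of a (3, 3) boolean array,
# which NumPy refuses as ambiguous and raises ValueError.
try:
    np.where(arr[:, :, 0] == 10 and arr[:, :, 1] == 15)
    and_failed = False
except ValueError:
    and_failed = True

# `&` combines the masks element-wise; the parentheses are required
# because `&` binds more tightly than `==`.
mask = (arr[:, :, 0] == 10) & (arr[:, :, 1] == 15) & (arr[:, :, 2] == 30)
rows, cols = np.where(mask)

print(and_failed)   # True
print(rows, cols)   # [0] [2]
```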

Use broadcasting to do this simultaneously:

>>> arr == np.array([10,15,30])
array([[[False, False, False],
        [False, False, False],
        [ True,  True,  True]],

       [[False, False, False],
        [False, False, False],
        [False, False, False]],

       [[False, False, False],
        [False, False, False],
        [False, False, False]]], dtype=bool)

>>> np.where( np.all(arr == np.array([10,15,30]), axis=-1) )
(array([0]), array([2]))
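If you want the matches as (row, col) pairs rather than a tuple of index arrays, the broadcasting version can be wrapped in a small helper. find_band_matches is a hypothetical name, np.argwhere is used in place of np.where to get one row per match, and the length check is an added assumption so that a values array of the wrong length fails loudly instead of broadcasting in an unexpected way:

```python
import numpy as np


def find_band_matches(image, values):
    """Hypothetical helper: (row, col) pairs where every band equals `values`."""
    values = np.asarray(values)
    # Assumption: exactly one target value per band.
    if values.shape != image.shape[-1:]:
        raise ValueError(f"expected {image.shape[-1]} values, got shape {values.shape}")
    # Compare all bands at once, then keep pixels where every band matched.
    mask = np.all(image == values, axis=-1)
    # argwhere returns an (n_matches, 2) array instead of a tuple of arrays.
    return np.argwhere(mask)


arr = np.array([[[6, 9, 4], [5, 2, 1], [10, 15, 30]],
                [[9, 0, 1], [4, 6, 4], [8, 3, 9]],
                [[6, 7, 4], [0, 1, 6], [4, 0, 1]]])

print(find_band_matches(arr, [10, 15, 30]))  # [[0 2]]
```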

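If the band indices you want are not contiguous, you can index them first and compare against the matching values:

ind_vals = np.array([0,2])
values = np.array([10,30])
where_mask = np.all(arr[:,:,ind_vals] == values, axis=-1)
index_pos = np.where(where_mask)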

Broadcast when you can.
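A single broadcast == only covers one exact value per band. When the conditions differ per band (the thresholds below are made up for illustration), build one mask per condition and combine them with &:

```python
import numpy as np

arr = np.array([[[6, 9, 4], [5, 2, 1], [10, 15, 30]],
                [[9, 0, 1], [4, 6, 4], [8, 3, 9]],
                [[6, 7, 4], [0, 1, 6], [4, 0, 1]]])

# Illustrative conditions: band 0 at least 8 AND band 2 below 10.
mask = (arr[:, :, 0] >= 8) & (arr[:, :, 2] < 10)
rows, cols = np.where(mask)
print(rows, cols)  # [1 1] [0 2]
```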

Spurred by @Jamie's comment, some interesting things to consider:

arr = np.random.randint(0,100,(5000,5000,3))

%timeit np.all(arr == np.array([10,15,30]), axis=-1)
1 loops, best of 3: 614 ms per loop

%timeit ((arr[:,:,0]==10) & (arr[:,:,1]==15) & (arr[:,:,2]==30))
1 loops, best of 3: 217 ms per loop

%timeit tmp = (arr == np.array([10,15,30])); (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
1 loops, best of 3: 368 ms per loop
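%timeit is an IPython magic. In a plain Python script, a comparable measurement can be sketched with the standard timeit module. The array is deliberately smaller than the 5000 x 5000 x 3 example (an assumption to keep memory modest), the function names are made up, and absolute numbers depend on your machine and NumPy version, so no timings are claimed here:

```python
import timeit

import numpy as np

# Smaller than the original example; treat results as relative only.
rng = np.random.default_rng(0)
arr = rng.integers(0, 100, size=(1000, 1000, 3))
target = np.array([10, 15, 30])


def broadcast_all():
    return np.all(arr == target, axis=-1)


def per_band_and():
    return (arr[:, :, 0] == 10) & (arr[:, :, 1] == 15) & (arr[:, :, 2] == 30)


def broadcast_then_and():
    tmp = arr == target
    return tmp[:, :, 0] & tmp[:, :, 1] & tmp[:, :, 2]


# The three approaches must agree before their timings are worth comparing.
assert np.array_equal(broadcast_all(), per_band_and())
assert np.array_equal(broadcast_all(), broadcast_then_and())

for fn in (broadcast_all, per_band_and, broadcast_then_and):
    # Best of 3 repeats, each running the function 5 times.
    best = min(timeit.repeat(fn, number=5, repeat=3)) / 5
    print(f"{fn.__name__}: {best * 1e3:.1f} ms per loop")
```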

So why is this?

First, compare the cost of a single comparison:

%timeit (arr[:,:,0]==10)
10 loops, best of 3: 51.2 ms per loop

%timeit (arr == np.array([10,15,30]))
1 loops, best of 3: 300 ms per loop

One would expect arr == np.array([10,15,30]) to be, at worst, three times as slow as arr[:,:,0]==10, since it does three times as many comparisons; instead it is roughly six times slower (300 ms vs 51.2 ms). Does anyone have an idea why?

Then, when reducing over the final axis, there are several ways to do it:

tmp = (arr == np.array([10,15,30]))

method1 = np.all(tmp,axis=-1)
method2 = (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
method3 = np.einsum('ij,ij,ij->ij',tmp[:,:,0] , tmp[:,:,1] , tmp[:,:,2])

np.array_equal(method1,method2)
True
np.array_equal(method1,method3)
True
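method2 spells out each band by hand. A sketch that chains & over any number of bands uses functools.reduce with operator.and_ (the 5-band random data is an illustrative assumption). Whether this keeps method2's speed advantage over np.all is something to measure, not assume:

```python
import functools
import operator

import numpy as np

rng = np.random.default_rng(1)
arr = rng.integers(0, 4, size=(50, 60, 5))  # 5 bands; small range so matches occur
target = np.array([1, 2, 3, 0, 1])

tmp = arr == target  # boolean, shape (50, 60, 5)

# Equivalent to tmp[:, :, 0] & tmp[:, :, 1] & ... & tmp[:, :, 4]
bands = [tmp[:, :, k] for k in range(tmp.shape[-1])]
method2_general = functools.reduce(operator.and_, bands)

method1 = np.all(tmp, axis=-1)
print(np.array_equal(method1, method2_general))  # True
```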

%timeit np.all(tmp,axis=-1)
1 loops, best of 3: 318 ms per loop

%timeit (tmp[:,:,0] & tmp[:,:,1] & tmp[:,:,2])
10 loops, best of 3: 68.2 ms per loop

%timeit np.einsum('ij,ij,ij->ij',tmp[:,:,0] , tmp[:,:,1] , tmp[:,:,2])
10 loops, best of 3: 38 ms per loop

The einsum speed-up is explained elsewhere, but it seems odd that there is such a large difference between np.all and chained & operators (318 ms vs 68.2 ms).

answered Oct 13 '22 at 22:10 by Daniel