 

How to find row of 2d array in 3d numpy array

I'm trying to find the rows in which a 2D array appears in a 3D NumPy ndarray. Here's an example of what I mean. Given:

arr = [[[0, 3], [3, 0]],
       [[0, 0], [0, 0]],
       [[3, 3], [3, 3]],
       [[0, 3], [3, 0]]]

I'd like to find all occurrences of:

[[0, 3], [3, 0]]

The result I'd like is:

[0, 3]

I tried to use argwhere but that unfortunately got me nowhere. Any ideas?
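A likely reason np.argwhere alone did not help: applied directly to the element-wise comparison, it reports every matching element as a (block, row, column) triple, rather than saying which whole 2x2 blocks match. A small sketch of that, assuming the example is first converted to an ndarray (the question shows a nested list) and using an illustrative name, target, for the pattern:

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])
target = np.array([[0, 3], [3, 0]])

# argwhere on the raw comparison lists each matching element as a
# (block, row, col) triple, so partially matching blocks 1 and 2 show up too.
hits = np.argwhere(arr == target)
print(hits.shape)  # (12, 3): one row per True element
```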

asked Apr 03 '16 by Shubhang Desai




1 Answer

Try

np.argwhere(np.all(arr==[[0,3], [3,0]], axis=(1,2)))
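A self-contained version of the one-liner, as a sketch under two assumptions. First, arr must be a NumPy array: with the plain nested list from the question, == compares the two lists as Python objects and yields a single False, so the reduction over axes 1 and 2 fails. Second, the asker wants a flat [0, 3], whereas np.argwhere returns a column of shape (2, 1), so .ravel() is added here:

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])

# Compare every 2x2 block with the pattern; keep blocks where all 4 entries match.
mask = np.all(arr == [[0, 3], [3, 0]], axis=(1, 2))

# argwhere returns indices as a (n, 1) column; ravel flattens it to 1-D.
rows = np.argwhere(mask).ravel()
print(rows)  # [0 3]
```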

How it works:

arr == [[0,3], [3,0]] returns

array([[[ True,  True],
        [ True,  True]],

       [[ True, False],
        [False,  True]],

       [[False,  True],
        [ True, False]],

       [[ True,  True],
        [ True,  True]]], dtype=bool)
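The comparison works because of broadcasting: the (2, 2) pattern is stretched along the leading axis, so it is compared against each of the four 2x2 blocks. A small check of the shapes, assuming the question's example as an ndarray; the explicit np.newaxis form is only there to show what broadcasting does implicitly:

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])
target = np.array([[0, 3], [3, 0]])

# (4, 2, 2) compared with (2, 2): target is broadcast over the first axis.
eq = arr == target
print(arr.shape, target.shape, eq.shape)  # (4, 2, 2) (2, 2) (4, 2, 2)

# Same result with the leading length-1 axis added explicitly.
print(np.array_equal(eq, arr == target[np.newaxis, :, :]))  # True
```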

This is a three-dimensional array of shape (4, 2, 2), so the innermost axis has length 2. The pairs along this axis are:

[True, True]
[True, True]
[True, False]
[False, True]
[False, True]
[True, False]
[True, True]
[True, True]
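To list those eight pairs directly, one can flatten the first two axes of the comparison result. A minimal sketch, again assuming arr is an ndarray and using an illustrative name, eq, for the boolean array:

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])
eq = arr == [[0, 3], [3, 0]]

# Merge the block and row axes: pairs come out in order
# block 0 row 0, block 0 row 1, block 1 row 0, ...
pairs = eq.reshape(-1, 2)
print(pairs.shape)  # (8, 2)
```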

Now np.all(arr == [[0, 3], [3, 0]], axis=2) checks whether both elements of each of these pairs are True, which reduces the shape from (4, 2, 2) to (4, 2):

array([[ True,  True],
       [False, False],
       [False, False],
       [ True,  True]], dtype=bool)
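A sketch of that reduction, assuming the same ndarray arr; the last line spells out by hand what axis=2 computes for a single pair (block 1, row 0):

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])
eq = arr == [[0, 3], [3, 0]]

# For each block and row: are both entries of the pair True?
row_ok = np.all(eq, axis=2)
print(row_ok.shape)  # (4, 2)

# The same test written out for block 1, row 0: 0 == 0 but 0 != 3.
print(bool(eq[1, 0, 0] and eq[1, 0, 1]))  # False
```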

You need one more reduction, because a block matches only if both of its rows match (both [0, 3] and [3, 0]). You can reduce the result again along its last axis, which is now axis 1:

np.all(np.all(arr == [[0, 3], [3, 0]], axis=2), axis=1)

Alternatively, pass a tuple to the axis parameter, as with axis=(1, 2) in the one-liner above, to reduce over both axes in a single call; for np.all this gives the same answer as the two-step version. Either way, the result is:

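array([ True, False, False,  True], dtype=bool)

Finally, np.argwhere returns the positions of the True entries, array([[0], [3]]), i.e. rows 0 and 3.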
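A sketch confirming that the two-step reduction and the tuple form agree on this example, followed by the final index step. np.flatnonzero is an alternative not used in the original answer; it returns the indices as a flat array, which matches the [0, 3] the asker wanted:

```python
import numpy as np

arr = np.array([[[0, 3], [3, 0]],
                [[0, 0], [0, 0]],
                [[3, 3], [3, 3]],
                [[0, 3], [3, 0]]])
eq = arr == [[0, 3], [3, 0]]

# Two steps: within each row (axis 2), then across rows (now axis 1).
nested = np.all(np.all(eq, axis=2), axis=1)
# One call reducing over axes 1 and 2 together.
at_once = np.all(eq, axis=(1, 2))
print(np.array_equal(nested, at_once))  # True

# argwhere gives a (2, 1) column; flatnonzero gives the same indices flat.
print(np.argwhere(at_once).shape)  # (2, 1)
print(np.flatnonzero(at_once))     # [0 3]
```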
answered Nov 14 '22 by ayhan