
How to take elements along a given axis, given by their indices?

I have a 3D array and I need to "squeeze" it over the last axis, so that I get a 2D array. I need to do it in the following way: for each pair of indices along the first two dimensions, I know the index along the third dimension from which the value should be taken.

For example, I know that if i1 == 2 and i2 == 7 then i3 == 11. It means that out[2,7] = inp[2,7,11]. This mapping from the first two dimensions to the third one is given in another 2D array. In other words, I have an array that holds the value 11 at position (2, 7).

So, my question is how to combine these two arrays (3D and 2D) to get the output array (2D).
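To pin down what is being asked, here is a minimal loop-based sketch of the desired result. The small arrays inp and i3 are made up for illustration; only the rule out[i1, i2] = inp[i1, i2, i3[i1, i2]] comes from the question.

```python
import numpy as np

# Hypothetical small input: a 3D array and a 2D map of third-axis indices.
inp = np.arange(24).reshape(2, 3, 4)
i3 = np.array([[1, 2, 3],
               [0, 1, 2]])

# Reference implementation of out[i1, i2] = inp[i1, i2, i3[i1, i2]].
out = np.empty(i3.shape, dtype=inp.dtype)
for i1 in range(inp.shape[0]):
    for i2 in range(inp.shape[1]):
        out[i1, i2] = inp[i1, i2, i3[i1, i2]]

print(out)
```

A vectorized solution should reproduce this array without the Python-level loops.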

Roman asked Jul 02 '17 07:07


2 Answers

In [635]: arr = np.arange(24).reshape(2,3,4)
In [636]: idx = np.array([[1,2,3],[0,1,2]])


In [637]: I,J = np.ogrid[:2,:3]
In [638]: arr[I,J,idx]
Out[638]: 
array([[ 1,  6, 11],
       [12, 17, 22]])
In [639]: arr
Out[639]: 
array([[[ 0,  1,  2,  3],   # 1
        [ 4,  5,  6,  7],   # 6
        [ 8,  9, 10, 11]],  # 11

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

I,J broadcast together to select a (2,3) set of values, matching idx:

In [640]: I
Out[640]: 
array([[0],
       [1]])
In [641]: J
Out[641]: array([[0, 1, 2]])
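To make the broadcasting step explicit, the sketch below expands I, J and idx to their common shape with np.broadcast_arrays. The names I_full, J_full and K_full are introduced here for illustration only.

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
idx = np.array([[1, 2, 3], [0, 1, 2]])
I, J = np.ogrid[:2, :3]  # shapes (2, 1) and (1, 3)

# Expand the three index arrays to their common broadcast shape, (2, 3).
I_full, J_full, K_full = np.broadcast_arrays(I, J, idx)
print(I_full.shape, J_full.shape, K_full.shape)

# Element [a, b] of the result is
# arr[I_full[a, b], J_full[a, b], K_full[a, b]].
result = arr[I_full, J_full, K_full]
print(np.array_equal(result, arr[I, J, idx]))
```

Indexing with the compact (2,1) and (1,3) arrays gives the same result as indexing with the fully expanded ones; NumPy performs the expansion implicitly.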

This is a generalization to 3d of the easier 2d problem - selecting one item from each row:

In [649]: idx
Out[649]: 
array([[1, 2, 3],
       [0, 1, 2]])
In [650]: idx[np.arange(2), [0,1]]
Out[650]: array([1, 1])

In fact we could convert the 3d problem into a 2d one:

In [655]: arr.reshape(6,4)[np.arange(6), idx.ravel()]
Out[655]: array([ 1,  6, 11, 12, 17, 22])
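The flattened result can be reshaped back to the shape of idx. A minimal sketch of the full round trip, with the intermediate names (n_rows, flat, picked) chosen here for illustration:

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
idx = np.array([[1, 2, 3], [0, 1, 2]])

n_rows = idx.size                               # one row per (i1, i2) pair
flat = arr.reshape(n_rows, arr.shape[-1])       # shape (6, 4)
picked = flat[np.arange(n_rows), idx.ravel()]   # one item from each row
out = picked.reshape(idx.shape)                 # back to shape (2, 3)

print(out)
```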

Generalizing the original case:

In [55]: arr = np.arange(24).reshape(2,3,4)                                     
In [56]: idx = np.array([[1,2,3],[0,1,2]])                                      
In [57]: IJ = np.ogrid[[slice(i) for i in idx.shape]]                           
In [58]: IJ                                                                     
Out[58]: 
[array([[0],
        [1]]), array([[0, 1, 2]])]
In [59]: (*IJ,idx)                                                              
Out[59]: 
(array([[0],
        [1]]), array([[0, 1, 2]]), array([[1, 2, 3],
        [0, 1, 2]]))
In [60]: arr[_]                                                                 
Out[60]: 
array([[ 1,  6, 11],
       [12, 17, 22]])
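The same steps can be wrapped in a small helper. This is only a sketch: select_last_axis is a hypothetical name, and it assumes idx has one dimension fewer than arr and indexes the last axis. A tuple of slices is passed to np.ogrid instead of a list.

```python
import numpy as np

def select_last_axis(arr, idx):
    """Return out[...] = arr[..., idx[...]] (hypothetical helper)."""
    # One open grid per dimension of idx; they broadcast against idx.
    grids = np.ogrid[tuple(slice(n) for n in idx.shape)]
    return arr[(*grids, idx)]

arr = np.arange(24).reshape(2, 3, 4)
idx = np.array([[1, 2, 3], [0, 1, 2]])
print(select_last_axis(arr, idx))

# The same helper applied to a 4D array with a 3D index array.
rng = np.random.default_rng(0)
big = rng.random((2, 3, 4, 5))
big_idx = rng.integers(0, 5, size=(2, 3, 4))
big_out = select_last_axis(big, big_idx)
```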

The key is in combining the IJ list of arrays with the idx to make a new indexing tuple. Constructing the tuple is a little messier if idx isn't the last index, but it's still possible. E.g.

In [61]: (*IJ[:-1],idx,IJ[-1])                                                  
Out[61]: 
(array([[0],
        [1]]), array([[1, 2, 3],
        [0, 1, 2]]), array([[0, 1, 2]]))
In [62]: arr.transpose(0,2,1)[_]                                                
Out[62]: 
array([[ 1,  6, 11],
       [12, 17, 22]])

Or, if it's easier, transpose arr so that the idx dimension is last. The key is that the index operation takes a tuple of index arrays, which broadcast against each other to select specific items. That's what ogrid does here: it creates the arrays that broadcast with idx.
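As an alternative that this answer does not use, NumPy 1.15 and later provide np.take_along_axis, which avoids building the index tuple by hand. The sketch below assumes idx has one dimension fewer than arr; pick_along_axis is a hypothetical helper name.

```python
import numpy as np

def pick_along_axis(a, index, axis):
    """Select one item along `axis` for each position in `index`."""
    # take_along_axis needs index to have the same ndim as a,
    # so add a length-1 dimension at the axis being indexed.
    expanded = np.expand_dims(index, axis)
    return np.take_along_axis(a, expanded, axis=axis).squeeze(axis=axis)

arr = np.arange(24).reshape(2, 3, 4)
idx = np.array([[1, 2, 3], [0, 1, 2]])

out = pick_along_axis(arr, idx, axis=-1)
print(out)

# The indexed axis does not have to be the last one.
out_t = pick_along_axis(arr.transpose(0, 2, 1), idx, axis=1)
print(np.array_equal(out, out_t))
```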

hpaulj answered Oct 15 '22 21:10


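import numpy as np

inp = np.random.random((20, 10, 5))  # simulate some input
i1, i2 = np.indices(inp.shape[:2])   # row and column indices, each of shape (20, 10)
# i3 must have the same 2D shape as i1 and i2; replace this with
# whatever mapping you want between (i1, i2) and i3
i3 = np.random.randint(0, inp.shape[2], size=inp.shape[:2])
out = inp[i1, i2, i3]                # out[a, b] == inp[a, b, i3[a, b]]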

See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing for more details
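np.indices returns dense index arrays, while np.ogrid in the other answer returns compact ones that rely on broadcasting. The sketch below, using made-up random data, checks that both forms select the same values. It assumes NumPy 1.17 or later for the sparse argument of np.indices and for np.random.default_rng.

```python
import numpy as np

rng = np.random.default_rng(0)
inp = rng.random((20, 10, 5))                           # simulated input
i3 = rng.integers(0, inp.shape[2], size=inp.shape[:2])  # (20, 10) index map

i1, i2 = np.indices(inp.shape[:2])               # dense: both (20, 10)
s1, s2 = np.indices(inp.shape[:2], sparse=True)  # compact: (20, 1) and (1, 10)

out_dense = inp[i1, i2, i3]
out_sparse = inp[s1, s2, i3]

print(out_dense.shape)
print(np.array_equal(out_dense, out_sparse))
```

The compact form avoids allocating two full-size index arrays, which may matter for large inputs.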

AGN Gazer answered Oct 15 '22 21:10