NumPy - Indexing one dimension of a multidimensional array

I have a NumPy array like this, with shape (6, 2, 4):

x = np.array([[[0, 3, 2, 0],
               [1, 3, 1, 1]],

              [[3, 2, 3, 3],
               [0, 3, 2, 0]],

              [[1, 0, 3, 1],
               [3, 2, 3, 3]],

              [[0, 3, 2, 0],
               [1, 3, 2, 2]],

              [[3, 0, 3, 1],
               [1, 0, 1, 1]],

              [[1, 3, 1, 1],
               [3, 1, 3, 3]]])

And I have a choices array like this, with shape (6, 4) and values 0 or 1:

choices = np.array([[1, 1, 1, 1],
                    [0, 1, 1, 0],
                    [1, 1, 1, 1],
                    [1, 0, 0, 0],
                    [1, 0, 1, 1],
                    [0, 0, 0, 1]])

How can I use the choices array to index only the middle dimension (the one of size 2) and get a new NumPy array with shape (6, 4), as efficiently as possible?

The result would be this:

[[1 3 1 1]
 [3 3 2 3]
 [3 2 3 3]
 [1 3 2 0]
 [1 0 1 1]
 [1 3 1 3]]
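To pin down what is being asked, here is a deliberately slow, loop-based sketch of the same selection: each output element is out[i, j] = x[i, choices[i, j], j]. The helper name pick_middle_loop is hypothetical; it is only meant as a reference for checking faster versions, not as a solution.

```python
import numpy as np

x = np.array([[[0, 3, 2, 0], [1, 3, 1, 1]],
              [[3, 2, 3, 3], [0, 3, 2, 0]],
              [[1, 0, 3, 1], [3, 2, 3, 3]],
              [[0, 3, 2, 0], [1, 3, 2, 2]],
              [[3, 0, 3, 1], [1, 0, 1, 1]],
              [[1, 3, 1, 1], [3, 1, 3, 3]]])

choices = np.array([[1, 1, 1, 1],
                    [0, 1, 1, 0],
                    [1, 1, 1, 1],
                    [1, 0, 0, 0],
                    [1, 0, 1, 1],
                    [0, 0, 0, 1]])


def pick_middle_loop(x, choices):
    """Hypothetical reference version: out[i, j] = x[i, choices[i, j], j]."""
    m, n = choices.shape
    out = np.empty((m, n), dtype=x.dtype)
    for i in range(m):          # outer block of x
        for j in range(n):      # position along the last axis
            out[i, j] = x[i, choices[i, j], j]
    return out


result = pick_middle_loop(x, choices)
print(result)
```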

I've tried x[:, choices, :], but it doesn't return what I want. I also tried x.take(choices, axis=1), with no luck.
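A quick shape check (a sketch, not part of the original question) suggests why both attempts fail: they apply the whole (6, 4) choices array to every outer block, so the size-2 axis is replaced by a (6, 4) block of indices and the result has shape (6, 6, 4, 4). The wanted values are only the "diagonal" entries of that larger array.

```python
import numpy as np

x = np.array([[[0, 3, 2, 0], [1, 3, 1, 1]],
              [[3, 2, 3, 3], [0, 3, 2, 0]],
              [[1, 0, 3, 1], [3, 2, 3, 3]],
              [[0, 3, 2, 0], [1, 3, 2, 2]],
              [[3, 0, 3, 1], [1, 0, 1, 1]],
              [[1, 3, 1, 1], [3, 1, 3, 3]]])
choices = np.array([[1, 1, 1, 1], [0, 1, 1, 0], [1, 1, 1, 1],
                    [1, 0, 0, 0], [1, 0, 1, 1], [0, 0, 0, 1]])

# Both attempts use the full (6, 4) choices array for *every* outer block.
a = x[:, choices, :]
b = x.take(choices, axis=1)
print(a.shape, b.shape)  # (6, 6, 4, 4) (6, 6, 4, 4)

# a[i, r, c, k] == x[i, choices[r, c], k]; the wanted value sits where r == i and c == k.
i = np.arange(6)[:, None]
k = np.arange(4)
wanted = a[i, i, k, k]
print(wanted.shape)  # (6, 4)
```

Extracting that diagonal works, but it builds 16 times more data than needed, which is why a direct per-element selection is preferable.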

asked Oct 14 '25 20:10 by Mustafa Süve


1 Answer

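Use np.take_along_axis (available since NumPy 1.15) to index along the second axis. choices[:,None] adds a length-1 axis so the indices have the same number of dimensions as x, and [:,0] drops that axis from the result: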

In [16]: np.take_along_axis(x,choices[:,None],axis=1)[:,0]
Out[16]: 
array([[1, 3, 1, 1],
       [3, 3, 2, 3],
       [3, 2, 3, 3],
       [1, 3, 2, 0],
       [1, 0, 1, 1],
       [1, 3, 1, 3]])
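The only shape trick in that call is the extra length-1 axis: choices[:,None] has shape (6, 1, 4), so the take_along_axis result also has shape (6, 1, 4) before [:,0] removes it. The sketch below wraps the same steps in a hypothetical helper, pick_along, that works for any axis via np.expand_dims and np.squeeze; the helper's name and interface are assumptions for illustration, not part of NumPy.

```python
import numpy as np

x = np.array([[[0, 3, 2, 0], [1, 3, 1, 1]],
              [[3, 2, 3, 3], [0, 3, 2, 0]],
              [[1, 0, 3, 1], [3, 2, 3, 3]],
              [[0, 3, 2, 0], [1, 3, 2, 2]],
              [[3, 0, 3, 1], [1, 0, 1, 1]],
              [[1, 3, 1, 1], [3, 1, 3, 3]]])
choices = np.array([[1, 1, 1, 1], [0, 1, 1, 0], [1, 1, 1, 1],
                    [1, 0, 0, 0], [1, 0, 1, 1], [0, 0, 0, 1]])


def pick_along(arr, idx, axis):
    """Hypothetical helper: pick one element along `axis` for each entry of idx.

    Assumes idx.shape equals arr.shape with `axis` removed.
    """
    idx = np.expand_dims(idx, axis)                   # e.g. (6, 4) -> (6, 1, 4) for axis=1
    picked = np.take_along_axis(arr, idx, axis=axis)  # same shape as idx
    return np.squeeze(picked, axis=axis)              # drop the length-1 axis again


print(choices[:, None].shape)  # (6, 1, 4)
result = pick_along(x, choices, axis=1)
print(result.shape)  # (6, 4)
```

The same helper with axis=2 and a (6, 2) index array would pick one column per row instead.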

Or with explicit integer-array (advanced) indexing, where the three index arrays broadcast to the (6, 4) shape of choices:

In [22]: m,n = choices.shape

In [23]: x[np.arange(m)[:,None],choices,np.arange(n)]
Out[23]: 
array([[1, 3, 1, 1],
       [3, 3, 2, 3],
       [3, 2, 3, 3],
       [1, 3, 2, 0],
       [1, 0, 1, 1],
       [1, 3, 1, 3]])
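In the integer-array version, np.arange(m)[:,None] has shape (6, 1), choices has shape (6, 4) and np.arange(n) has shape (4,). NumPy broadcasts the three index arrays to (6, 4), so output element [i, j] is x[i, choices[i, j], j]. The sketch below checks that shape with np.broadcast. Because the middle axis has length 2 here, it also shows a different technique that is not from the answer above: a plain np.where select between x[:, 0, :] and x[:, 1, :]. That only works when choices holds 0s and 1s, and which variant is faster is worth timing on real data rather than assuming.

```python
import numpy as np

x = np.array([[[0, 3, 2, 0], [1, 3, 1, 1]],
              [[3, 2, 3, 3], [0, 3, 2, 0]],
              [[1, 0, 3, 1], [3, 2, 3, 3]],
              [[0, 3, 2, 0], [1, 3, 2, 2]],
              [[3, 0, 3, 1], [1, 0, 1, 1]],
              [[1, 3, 1, 1], [3, 1, 3, 3]]])
choices = np.array([[1, 1, 1, 1], [0, 1, 1, 0], [1, 1, 1, 1],
                    [1, 0, 0, 0], [1, 0, 1, 1], [0, 0, 0, 1]])

m, n = choices.shape
rows = np.arange(m)[:, None]  # shape (6, 1): selects the outer block
cols = np.arange(n)           # shape (4,): selects the last-axis position

# The three index arrays broadcast together to the shape of choices.
print(np.broadcast(rows, choices, cols).shape)  # (6, 4)
via_index = x[rows, choices, cols]

# Alternative for a length-2 middle axis only (assumes choices is all 0s and 1s):
# take x[:, 1, :] where choices == 1, otherwise x[:, 0, :].
via_where = np.where(choices == 1, x[:, 1, :], x[:, 0, :])

print(np.array_equal(via_index, via_where))  # True
```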
answered Oct 17 '25 09:10 by Divakar


