 

Indexing a numpy array with a numpy array of indexes [duplicate]

I have a 3D NumPy array data and another array pos of indexes. Each index is itself a NumPy array of three coordinates, which makes pos a 2D array:

import numpy as np
data = np.arange(8).reshape(2, 2, -1)
#array([[[0, 1],
#        [2, 3]],
#
#       [[4, 5],
#        [6, 7]]])

pos = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])
#array([[1, 1, 0],
#       [0, 1, 0],
#       [1, 0, 0]])

I want to select and/or mutate the elements from data using the indexes from pos. I can do the selection using a for loop or a list comprehension:

[data[tuple(i)] for i in pos]
#[6, 2, 4]
data[tuple(i for i in pos.T)]
#array([6, 2, 4])

But this does not feel like the NumPy way. Is there a vectorized NumPy solution to this problem?

Asked May 16 '18 03:05 by DYZ




1 Answer

You can split pos into 3 separate index arrays, one per axis, and index with them, like so:

>>> i, j, k = pos.T
>>> data[i, j, k]
array([6, 2, 4])

Here, the number of columns in pos corresponds to the number of dimensions of data. As long as you're dealing with 3D arrays, getting i, j, and k will never get more complicated than this.
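The question also asks about mutating the selected elements. The same i, j, k index arrays work on the left-hand side of an assignment, which writes into data in place. One caveat: if pos contains duplicate rows, an augmented assignment like data[i, j, k] += 1 is evaluated as data[i, j, k] = data[i, j, k] + 1, so a repeated position is only updated once, whereas np.add.at applies every occurrence. A minimal sketch using the question's data and pos; pos_dup and the zero arrays a and b are made up for illustration:

```python
import numpy as np

data = np.arange(8).reshape(2, 2, -1)
pos = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])

# Split pos into one index array per axis, as in the answer.
i, j, k = pos.T

# Assignment through fancy indexing modifies data in place.
data[i, j, k] = -1
print(data)

# Hypothetical pos_dup: the row [1, 1, 0] appears twice.
pos_dup = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]])
i2, j2, k2 = pos_dup.T

a = np.zeros((2, 2, 2), dtype=int)
a[i2, j2, k2] += 1  # gather, add, scatter: the repeated position ends up 1, not 2

b = np.zeros((2, 2, 2), dtype=int)
np.add.at(b, (i2, j2, k2), 1)  # unbuffered: each occurrence is added

print(a[1, 1, 0], b[1, 1, 0])  # 1 2
```

Plain assignment is enough when every row of pos is unique; np.add.at (or another ufunc's .at method) is the safer choice when rows may repeat and every occurrence should count.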

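You can shorten the split-and-index step to a single expression by passing the columns of pos as a tuple:

>>> data[tuple(pos.T)]
array([6, 2, 4])

On older NumPy releases the list form data[[*pos.T]] gave the same result, but NumPy 1.23 and later treat a non-tuple sequence as a single index array into the first axis, so use a tuple.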
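Passing a tuple of index arrays (a tuple, not a list) is not tied to 3D: for an n-dimensional array, pos has n columns and tuple(pos.T) supplies one index array per axis. A minimal sketch, assuming a made-up 4D array data4 and randomly drawn index rows pos4, that also checks the result against two other routes: converting each row to a flat offset with np.ravel_multi_index, and the per-row loop from the question:

```python
import numpy as np

# Hypothetical 4D example: shape and index rows are arbitrary illustration data.
rng = np.random.default_rng(0)
data4 = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Each row of pos4 is one (a, b, c, d) index; each column stays within its axis bound.
pos4 = np.column_stack([rng.integers(0, n, size=6) for n in data4.shape])

# One index array per axis: pos4.T has shape (ndim, number_of_points).
picked = data4[tuple(pos4.T)]

# Alternative: turn each row into a flat C-order offset, then take from the flattened array.
flat = np.ravel_multi_index(tuple(pos4.T), data4.shape)
picked_flat = np.take(data4, flat)

# Reference: the per-row loop from the question.
picked_loop = np.array([data4[tuple(p)] for p in pos4])

print(picked)
print(np.array_equal(picked, picked_flat), np.array_equal(picked, picked_loop))
```

The flat-index version replaces the tuple of index arrays with a single integer array, which can be handy if the offsets are stored or reused; both forms avoid a Python-level loop over the rows.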
Answered Sep 22 '22 16:09 by cs95