 

numpy: select one element per row using a 1D array of indices

Tags:

python

numpy

Suppose we have a square n×n array. For example, with n = 3:

import numpy as np

arr = np.array([[0, 1, 2],
                [3, 4, 5],
                [6, 7, 8]])

We also have an array holding one column index for each ROW. For example:

myidx = np.array([1, 2, 1], dtype=np.int64)

I want to get:

[1, 5, 7]

That is: from row [0, 1, 2] take the element at index 1, from row [3, 4, 5] the element at index 2, and from row [6, 7, 8] the element at index 1.
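The selection rule above can be written as a plain Python loop, which is handy as a reference to check any vectorised version against. A minimal sketch; `select_per_row_loop` is a hypothetical helper name, not a NumPy function:

```python
import numpy as np

def select_per_row_loop(arr, idx):
    """Reference version: pick arr[i, idx[i]] for every row i (slow, pure Python loop)."""
    return np.array([arr[i, j] for i, j in enumerate(idx)])

arr = np.array([[0, 1, 2],
                [3, 4, 5],
                [6, 7, 8]])
myidx = np.array([1, 2, 1], dtype=np.int64)

print(select_per_row_loop(arr, myidx))  # [1 5 7]
```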

I'm stuck: I can't work out how to select elements this way with standard NumPy indexing. Thanks for any answers.

dondublon, asked Dec 25 '22 04:12


2 Answers

There's no really pretty way, but indexing with an array of row numbers alongside myidx does what you're looking for :)

In [1]: import numpy as np

In [2]: arr = np.array([[0, 1, 2],
   ...:                 [3, 4, 5],
   ...:                 [6, 7, 8]])

In [3]: myidx = np.array([1, 2, 1], dtype=np.int64)

In [4]: arr[np.arange(len(myidx)), myidx]
Out[4]: array([1, 5, 7])
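This works because the two integer index arrays are broadcast together and paired element-wise, so the result is [arr[0, myidx[0]], arr[1, myidx[1]], arr[2, myidx[2]]]. Below is a sketch wrapping it in a function (the name `select_per_row` is hypothetical), which also handles non-square arrays as long as there is exactly one index per row. For comparison it also shows `np.take_along_axis`, a different built-in (NumPy 1.15+) that performs the same selection once the indices are given an extra axis:

```python
import numpy as np

def select_per_row(arr, idx):
    """Return arr[i, idx[i]] for each row i using integer-array (fancy) indexing."""
    rows = np.arange(arr.shape[0])  # 0, 1, ..., n_rows - 1
    return arr[rows, idx]           # rows and idx are paired element-wise

arr = np.array([[0, 1, 2],
                [3, 4, 5],
                [6, 7, 8]])
myidx = np.array([1, 2, 1], dtype=np.int64)

print(select_per_row(arr, myidx))  # [1 5 7]

# Same selection via take_along_axis: idx needs shape (n_rows, 1); result has shape (n_rows, 1)
along = np.take_along_axis(arr, myidx[:, np.newaxis], axis=1)[:, 0]
print(along)  # [1 5 7]

# Also works for a non-square 2 x 4 array
wide = np.arange(8).reshape(2, 4)  # [[0 1 2 3], [4 5 6 7]]
print(select_per_row(wide, np.array([3, 0])))  # [3 4]
```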
Wolph, answered Dec 28 '22 13:12


A simpler way to get there is NumPy's choose function, applied to the transposed array:

numpy.choose(myidx, arr.transpose())
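Why this works: the rows of arr.T are the columns of arr, and `np.choose` takes element j of the result from choice array myidx[j], i.e. arr.T[myidx[j], j], which is arr[j, myidx[j]]. A small check of that equivalence on the question's data (the names `via_choose` and `explicit` are just for illustration):

```python
import numpy as np

arr = np.array([[0, 1, 2],
                [3, 4, 5],
                [6, 7, 8]])
myidx = np.array([1, 2, 1], dtype=np.int64)

# Each row of arr.T (i.e. each column of arr) acts as one "choice" array.
via_choose = np.choose(myidx, arr.T)
# Explicit form of what choose does here: result[j] = arr.T[myidx[j], j]
explicit = np.array([arr.T[k, j] for j, k in enumerate(myidx)])

print(via_choose)                            # [1 5 7]
print(np.array_equal(via_choose, explicit))  # True
```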
belgraviton, answered Dec 28 '22 13:12