
numpy array slicing, get one from each third dimension

I have a 3D array of data. I also have a 2D array of indices whose shape matches the first two dimensions of the data array; it specifies the index I want to pluck from the third dimension at each position, to make a 2D array. E.g.:

 import numpy as np
 a = np.arange(3 * 5 * 7).reshape((3, 5, 7))
 getters = np.array([0, 1, 2] * 5).reshape(3, 5)

What I'm looking for is a syntax like a[:, :, getters] that returns an array of shape (3,5) by indexing independently into the third dimension of each item. However, a[:, :, getters] returns an array of shape (3,5,3,5). I can do it by iterating and building a new array, but this is pretty slow:

 np.array([[col[getters[ri, ci]] for ci, col in enumerate(row)] for ri, row in enumerate(a)])
 # gives array([[  0,   8,  16,  21,  29],
 #    [ 37,  42,  50,  58,  63],
 #    [ 71,  79,  84,  92, 100]])

Is there a neat+fast way?

Dan Stowell asked Nov 30 '25 18:11


1 Answer

If I understand you correctly, I've done something like this using fancy indexing:

>>> k,j = np.meshgrid(np.arange(a.shape[1]),np.arange(a.shape[0]))
>>> k
array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]])
>>> j
array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2]])
>>> a[j,k,getters]
array([[  0,   8,  16,  21,  29],
       [ 37,  42,  50,  58,  63],
       [ 71,  79,  84,  92, 100]])

Of course, you can keep k and j around and reuse them as often as you'd like, as long as the first two dimensions of a stay the same. As pointed out by DSM in the comments, j,k = np.indices(a.shape[:2]) should also work instead of meshgrid. Which one is faster apparently depends on the number of elements involved.
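For anyone who wants to compare the variants side by side, here is a minimal sketch using the question's data. The np.indices version is the one mentioned above. The np.ogrid and np.take_along_axis versions are not from the original answer; they are alternatives I'm adding for illustration, and take_along_axis assumes NumPy 1.15 or later. The names via_indices, via_ogrid and via_take are made up for this example.

```python
import numpy as np

# Same data as in the question.
a = np.arange(3 * 5 * 7).reshape((3, 5, 7))
getters = np.array([0, 1, 2] * 5).reshape(3, 5)

# Full index grids: j and k both have shape (3, 5).
j, k = np.indices(a.shape[:2])
via_indices = a[j, k, getters]

# Open grids (an alternative, not from the answer): shapes (3, 1) and (1, 5),
# which broadcast against getters instead of storing two full (3, 5) arrays.
jj, kk = np.ogrid[:a.shape[0], :a.shape[1]]
via_ogrid = a[jj, kk, getters]

# take_along_axis (also an alternative): the index array must have the same
# number of dimensions as a, so add a trailing axis and drop it afterwards.
via_take = np.take_along_axis(a, getters[:, :, None], axis=2)[:, :, 0]
```

All three should give the same (3, 5) array as the list comprehension in the question. I haven't timed them, so treat any speed difference as something to measure on your own array sizes.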

mgilson answered Dec 03 '25 07:12


