
Iterating over the first d axes of a NumPy array

I'm given an array with an arbitrary number of axes, and I want to iterate over, say, the first d of them. How do I do this?

Initially I thought I would make an array containing all the indices I want to loop through, using

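import numpy as np

i = np.indices(a.shape[:d])
# one row per index combination over the first d axes
indices = np.transpose(np.asarray([x.flatten() for x in i]))
for idx in indices:
    a[idx]  # idx is an integer array here, not a tuple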

But apparently I cannot index an array like that, i.e. using another array containing the index.
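The reason is that an integer array used as an index triggers NumPy's advanced ("fancy") indexing: a[idx] picks whole slices along axis 0, one per entry of idx, instead of treating idx as one coordinate per axis. Converting each row to a tuple turns it into an ordinary multi-axis index. A minimal sketch, assuming a small random 3-d array and d = 2:

```python
import numpy as np

a = np.random.random((2, 3, 4))
d = 2

# One row per index combination over the first d axes (same as the question's code).
i = np.indices(a.shape[:d])
indices = i.reshape(d, -1).T      # shape (6, 2) here

idx = indices[4]                  # array([1, 1])
fancy = a[idx]                    # advanced indexing: a[1] taken twice, shape (2, 3, 4)
single = a[tuple(idx)]            # basic indexing: a[1, 1], shape (4,)

# The question's loop works once each row is turned into a tuple.
subs = [a[tuple(row)] for row in indices]
```

Each element of subs has shape a.shape[d:], which is what the loop was after.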

Asked by Mark van der Wilk, Aug 11 '14 at 17:08


2 Answers

You can use np.ndindex, which iterates over every index tuple of a given shape:

import numpy as np

d = 2
a = np.random.random((2, 3, 4))
for i in np.ndindex(a.shape[:d]):
    print(i, a[i])

Output:

(0, 0) [ 0.72730488  0.2349532   0.36569509  0.31244037]
(0, 1) [ 0.41738425  0.95999499  0.63935274  0.9403284 ]
(0, 2) [ 0.90690468  0.03741634  0.33483221  0.61093582]
(1, 0) [ 0.06716122  0.52632369  0.34441657  0.80678942]
(1, 1) [ 0.8612884   0.22792671  0.15628046  0.63269415]
(1, 2) [ 0.17770685  0.47955698  0.69038541  0.04838387]
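Because each i is a tuple of integers, a[i] is basic indexing and returns a view into a, so the same loop can also modify a in place. A small sketch with the same shapes (the operation, centring each sub-array, is just an arbitrary example):

```python
import numpy as np

d = 2
a = np.random.random((2, 3, 4))

for i in np.ndindex(a.shape[:d]):
    sub = a[i]           # view of shape a.shape[d:], here (4,)
    sub -= sub.mean()    # in-place update writes through to a

# Every length-4 sub-array of a now has mean approximately 0.
```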
Answered by shx2, Sep 28 '22 at 00:09


You could reshape a to collapse the first d dimensions into one:

for x in a.reshape(-1, *a.shape[d:]):
    print(x)

or

aa = a.reshape(-1, *a.shape[d:])
for i in range(aa.shape[0]):
    print(aa[i])
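Reshaping discards the original multi-index, but np.unravel_index can recover it from the flat position. Keep in mind that reshape returns a view only when the memory layout allows it; for non-contiguous input it may return a copy, in which case writes to aa would not reach a. A sketch, assuming the same 3-d array and d = 2:

```python
import numpy as np

d = 2
a = np.random.random((2, 3, 4))

aa = a.reshape(-1, *a.shape[d:])              # shape (6, 4)
for k in range(aa.shape[0]):
    idx = np.unravel_index(k, a.shape[:d])    # flat position -> (i0, i1)
    # aa[k] and a[idx] refer to the same sub-array
    assert np.array_equal(aa[k], a[idx])
```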

We really need to know more about what you need to do with a[i].


shx2's answer uses np.ndindex, which is also mentioned in the documentation for np.ndenumerate. It could be used as:

for i in np.ndindex(a.shape[:d]):
    print(i)
    print(a[i])

Here i is a tuple. It's instructive to look at the Python source of these functions; ndindex, for example, is built on nditer.
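To illustrate the idea, here is a rough sketch of an ndindex-like generator built on nditer's multi_index flag. The name iter_indices is hypothetical and this is not NumPy's actual source; a zero-strided dummy array from np.broadcast_to lets nditer walk the shape without allocating one element per index.

```python
import numpy as np

def iter_indices(shape):
    """Yield index tuples over `shape` in C order (hypothetical helper)."""
    # Zero-strided, read-only dummy array: no per-element memory is allocated.
    dummy = np.broadcast_to(np.int8(0), shape)
    it = np.nditer(dummy, flags=["multi_index", "zerosize_ok"], order="C")
    for _ in it:
        yield it.multi_index

a = np.random.random((2, 3, 4))
d = 2
for i in iter_indices(a.shape[:d]):
    print(i, a[i].shape)
```

In practice np.ndindex already does this job; the sketch only shows how multi_index exposes the current position during iteration.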

Answered by hpaulj, Sep 28 '22 at 00:09