 

iterate over slices of an ndarray

Say I have a 3D numpy.array, e.g. with dimensions x, y, z. Is there a way to iterate over slices along a particular axis? Something like:

for layer in data.slices(dim=2):  # wished-for API; ndarray has no .slices() method
    # do something with layer
    ...

Edit: To clarify, the example is a 3-dimensional array, i.e. shape=(len_x, len_y, len_z). Elazar's and, equivalently, kamjagin's solutions work, but they aren't very general: you have to construct the [:, :, i] index by hand, which means you need to know the number of dimensions, so the code can't handle arrays of arbitrary dimensionality. You can fill in the missing dimensions with something like [..., :], but again you still have to construct this yourself.

Sorry, I should have been clearer; the example was a bit too simple!
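For illustration, the [:, :, i] index does not actually have to be typed by hand: a NumPy index is just a tuple, and slice(None) is the object that ':' stands for, so the tuple can be assembled for any axis. A minimal sketch, where iter_slices is a hypothetical helper name, not a NumPy function:

```python
import numpy as np


def iter_slices(arr, axis):
    """Yield arr[:, ..., :, i] for each i along `axis` (hypothetical helper)."""
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for a {arr.ndim}-d array")
    axis %= arr.ndim  # accept negative axes such as -1
    leading = (slice(None),) * axis  # one ':' for every axis before `axis`
    for i in range(arr.shape[axis]):
        # e.g. axis=2 gives (slice(None), slice(None), i), i.e. arr[:, :, i];
        # axes after `axis` are implicitly taken whole
        yield arr[leading + (i,)]


data = np.arange(24).reshape(2, 3, 4)
for layer in iter_slices(data, 2):
    print(layer.shape)  # (2, 3), printed four times
```

Because this is basic indexing (integers and slices only), each yielded slice is a view into arr. np.take(arr, i, axis=axis) produces the same values but returns a copy.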

Asked Jun 27 '13 22:06 by lost





1 Answer

Iterating over the first dimension is easy: a for loop over an ndarray yields its slices along axis 0, as shown below. To iterate over any other axis, roll that axis to the front and do the same:

>>> import numpy as np
>>> data = np.arange(24).reshape(2, 3, 4)
>>> for dim_0_slice in data:  # the first dimension is easy
...     print(dim_0_slice)
... 
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
>>> for dim_1_slice in np.rollaxis(data, 1):  # for the others, roll that axis to the front
...     print(dim_1_slice)
... 
[[ 0  1  2  3]
 [12 13 14 15]]
[[ 4  5  6  7]
 [16 17 18 19]]
[[ 8  9 10 11]
 [20 21 22 23]]
>>> for dim_2_slice in np.rollaxis(data, 2):
...     print(dim_2_slice)
... 
[[ 0  4  8]
 [12 16 20]]
[[ 1  5  9]
 [13 17 21]]
[[ 2  6 10]
 [14 18 22]]
[[ 3  7 11]
 [15 19 23]]
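The same trick can be wrapped into a function that works for any axis of an array of any dimensionality. This sketch uses np.moveaxis instead of np.rollaxis, because NumPy's documentation keeps rollaxis for backward compatibility and recommends moveaxis for new code. The name slices_along is hypothetical:

```python
import numpy as np


def slices_along(arr, axis):
    """Iterate over the slices of `arr` along `axis` (hypothetical helper).

    np.moveaxis returns a view with `axis` moved to the front, and
    iterating over an ndarray walks its first axis, so each slice is
    itself a view into `arr` and no data is copied.
    """
    return iter(np.moveaxis(arr, axis, 0))


data = np.arange(24).reshape(2, 3, 4)
for dim_2_slice in slices_along(data, 2):
    print(dim_2_slice)  # the same four 2x3 slices as np.rollaxis(data, 2)
```

Since the slices are views, assigning into one (e.g. dim_2_slice[...] = 0) writes through to data.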

EDIT: Some timings comparing the different methods on a largish array:

In [7]: a = np.arange(200*100*300).reshape(200, 100, 300)

In [8]: %timeit for j in range(100): a[:, j]
10000 loops, best of 3: 60.2 us per loop

In [9]: %timeit for j in range(100): a[:, j, :]
10000 loops, best of 3: 82.8 us per loop

In [10]: %timeit for j in np.rollaxis(a, 1): j
10000 loops, best of 3: 28.2 us per loop

In [11]: %timeit for j in np.swapaxes(a, 0, 1): j
10000 loops, best of 3: 26.7 us per loop
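To rerun a comparison like this outside IPython, the standard-library timeit module can time the same four loops. This is a sketch: the repeat counts are arbitrary choices, and no numbers are claimed here since results depend on the machine and NumPy version. It also checks that indexing and swapaxes visit the same slices.

```python
import timeit

import numpy as np

# Same shape as the IPython session: (200, 100, 300)
a = np.arange(200 * 100 * 300).reshape(200, 100, 300)


def index_colon():
    for j in range(a.shape[1]):
        a[:, j]


def index_colon_colon():
    for j in range(a.shape[1]):
        a[:, j, :]


def via_rollaxis():
    for s in np.rollaxis(a, 1):
        s


def via_swapaxes():
    for s in np.swapaxes(a, 0, 1):
        s


# Sanity check: indexing and swapaxes yield the same slices in the same order
assert all(np.array_equal(a[:, j], s) for j, s in enumerate(np.swapaxes(a, 0, 1)))

candidates = {
    "a[:, j]": index_colon,
    "a[:, j, :]": index_colon_colon,
    "np.rollaxis": via_rollaxis,
    "np.swapaxes": via_swapaxes,
}

for name, func in candidates.items():
    # Best of 3 repeats of 1000 calls each, reported per call
    best = min(timeit.repeat(func, number=1000, repeat=3)) / 1000
    print(f"{name:12s} {best * 1e6:8.1f} us per loop")
```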
Answered Sep 29 '22 04:09 by Jaime
