
Construct array by sampling over every n'th element along last axis

Let a be a (not necessarily one-dimensional) NumPy array with n * m elements along its last axis. I wish to "split" this array along its last axis into n pieces, where piece i takes every n-th element starting at offset i, for i from 0 up to n - 1.

To be explicit: if a has shape (k, n * m), I wish to construct the array of shape (n, k, m) given by

np.array([a[:, i::n] for i in range(n)])

My problem is that, although this does return the array I am after, I suspect there is a neater and more efficient NumPy routine for it.
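
For concreteness, here is a small sketch of what the construction above produces. The sizes k = 2, n = 3, m = 2 and the name `stacked` are purely illustrative and not part of the real problem.

```python
import numpy as np

k, n, m = 2, 3, 2                            # illustrative sizes only
a = np.arange(k * n * m).reshape(k, n * m)
# a = [[ 0  1  2  3  4  5]
#      [ 6  7  8  9 10 11]]

# Piece i holds columns i, i + n, i + 2n, ... of a.
stacked = np.array([a[:, i::n] for i in range(n)])

print(stacked.shape)   # (3, 2, 2)
print(stacked[1])      # [[ 1  4]
                       #  [ 7 10]]
```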

Cheers!


1 Answer

I think this does what you want, without loops. I tested it on 2D inputs; it may need some adjustments for more dimensions.

# Integer division: np.repeat needs an integer repeat count (a.size/n is a float in Python 3).
indexes = np.arange(0, a.size*n, n) + np.repeat(np.arange(n), a.size // n)
np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)

In my testing it is a bit slower than your original list solution.
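
Another option worth trying is to reshape the last axis instead of building an index array. The sketch below is not part of the tested code above: `split_last_axis` is a hypothetical helper name, and it assumes the length of the last axis is divisible by n. It should work for any number of leading dimensions.

```python
import numpy as np

def split_last_axis(a, n):
    """Return an array whose i-th entry equals a[..., i::n]."""
    if a.shape[-1] % n != 0:
        raise ValueError("last axis length must be divisible by n")
    # Group the last axis into chunks of n consecutive elements:
    # grouped[..., c, i] == a[..., c * n + i]
    grouped = a.reshape(*a.shape[:-1], -1, n)
    # Bring the position-within-chunk axis to the front.
    return np.moveaxis(grouped, -1, 0)

# Check against the list-comprehension version on a 3D input.
a = np.arange(2 * 3 * 12).reshape(2, 3, 12)
n = 4
expected = np.array([a[..., i::n] for i in range(n)])
result = split_last_axis(a, n)
assert result.shape == (4, 2, 3, 3)
assert np.array_equal(result, expected)
```

For a C-contiguous input, both `reshape` and `np.moveaxis` return views, so no data is copied. The result is not contiguous in memory, so wrap it in `np.ascontiguousarray` if later code needs that.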

John Zwinck answered Mar 29 '26 03:03



