Make every n-th slice of a 3-d numpy array consecutive

Statement

Let's say we have a 3-dimensional numpy array, A, with shape (X, Y, Z). I want to create a new array B, which also has shape (X, Y, Z).

We want the first n slices (:n) of B along the zeroth axis to be every m-th slice (::m) of A along the zeroth axis. Here X = n * m, so each such selection of A contains exactly n slices.

Slices n:2*n of B should then be every m-th slice of A starting from index 1 (1::m), and so on for the rest of the array.

What is the best way to achieve this using vectorized numpy computations?
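The reordering described above is a permutation of the indices along axis 0: B[j*n + k] should equal A[k*m + j]. A minimal sketch of that permutation, assuming X = n * m (the sizes below are illustrative and match the example that follows):

```python
import numpy as np

# Illustrative sizes: n repetitions of m distinct slices, so X = n * m.
n, m = 3, 5
X = n * m

# Lay the axis-0 indices of A out as an (n, m) grid: row k, column j holds k*m + j.
# Reading the grid column by column lists A[0::m], then A[1::m], and so on.
idx = np.arange(X).reshape(n, m).T.ravel()
print(idx.tolist())
# [0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]

# For any array A with A.shape[0] == X, the desired result is then B = A[idx].
```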

Example

The statement above is easiest to understand with an example, so let's start by setting up an example array A:

import numpy as np

# Create array A with shape (15, 3, 3)
n = 3; m = 5
a = np.array([i * np.eye(3) for i in range(1, 1+m)])
A = np.tile(a, (n, 1, 1))

Looking at some of the slices of A along the zeroth axis, we have:

print(A[0])
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
print(A[1])
[[2. 0. 0.]
 [0. 2. 0.]
 [0. 0. 2.]]

...

print(A[4])
[[5. 0. 0.]
 [0. 5. 0.]
 [0. 0. 5.]]
print(A[5])
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

and so on.

The values in A are not important, but should help illustrate the original statement.

I would like to know if we can create array B using numpy functions only. Array B should have slices:

print(B[0])
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
print(B[1])
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
print(B[2])
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
print(B[3])
[[2. 0. 0.]
 [0. 2. 0.]
 [0. 0. 2.]]

and so on.
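Because this particular A is just n stacked copies of a, the target B for this example is the same as repeating each slice of a n times in a row. That gives a convenient reference to test any candidate solution against; note it relies on A having been built with np.tile and does not solve the general problem.

```python
import numpy as np

n, m = 3, 5
a = np.array([i * np.eye(3) for i in range(1, 1 + m)])
A = np.tile(a, (n, 1, 1))

# Only valid for this tiled example: gathering A[i::m] yields n copies of a[i].
expected_B = np.repeat(a, n, axis=0)
print(expected_B[:, 0, 0].tolist())
# [1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0]
```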

Is there any way to generate B from A with a pure numpy solution?

What I've tried

The following gives B as desired, but it becomes tedious as m gets large:

# vstack solution
B = np.vstack((A[::m], A[1::m], A[2::m], A[3::m], A[4::m]))

Using a list comprehension also works, but I'd like to avoid using loops:

# List comprehension solution
B = np.vstack([A[i::m] for i in range(m)])
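One loop-free way to reproduce exactly what the list comprehension does, including the case where X is not a multiple of m, is to sort the axis-0 indices by their phase i % m with a stable sort. This is a swapped-in technique offered as a sketch; the helper name phase_order is hypothetical.

```python
import numpy as np

def phase_order(A, m):
    """Loop-free equivalent of np.vstack([A[i::m] for i in range(m)]).

    A stable sort of the axis-0 indices by (index % m) keeps indices that share
    a phase in increasing order, which is the concatenation
    A[0::m], A[1::m], ..., A[m-1::m].
    """
    phase = np.arange(A.shape[0]) % m
    return A[np.argsort(phase, kind="stable")]

# Works even when A.shape[0] is not a multiple of m (here 7 slices, m = 3).
A = np.arange(7 * 2 * 2).reshape(7, 2, 2)
B = phase_order(A, 3)
print(B[:, 0, 0].tolist())
# [0, 12, 24, 4, 16, 8, 20]
```

The sort costs O(X log X) on the index vector only; the slices themselves are gathered once by fancy indexing, which returns a copy.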
asked Feb 05 '26 08:02 by jwalton


1 Answer

I think this does what you want:

import numpy as np

# Create array A with shape (15, 3, 3)
a = np.array([i * np.eye(3) for i in range(1, 6)])
A = np.tile(a, (3, 1, 1))

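# Split axis 0 into (3 repetitions, 5 distinct slices), then swap those two axes
B = np.swapaxes(A.reshape(3, 5, 3, 3), 0, 1)
# Merge the swapped axes back into a single axis of length 15
B = B.reshape(-1, 3, 3)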
print(B)
# [[[1. 0. 0.]
#   [0. 1. 0.]
#   [0. 0. 1.]]
#
#  [[1. 0. 0.]
#   [0. 1. 0.]
#   [0. 0. 1.]]
#
#  [[1. 0. 0.]
#   [0. 1. 0.]
#   [0. 0. 1.]]
#
#  [[2. 0. 0.]
#   [0. 2. 0.]
#   [0. 0. 2.]]
# ...
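The sizes above are hard-coded for the example. A sketch of the same reshape/swapaxes idea written for any n and m, assuming A.shape[0] is a multiple of m (the helper name interleave_phases is hypothetical):

```python
import numpy as np

def interleave_phases(A, m):
    """Reorder axis 0 so the result is A[0::m], A[1::m], ..., A[m-1::m] stacked.

    Requires A.shape[0] to be a multiple of m.
    """
    X = A.shape[0]
    if X % m:
        raise ValueError("A.shape[0] must be a multiple of m")
    n = X // m
    # (X, ...) -> (n, m, ...): axis 0 counts repetitions, axis 1 counts phases.
    grid = A.reshape(n, m, *A.shape[1:])
    # Swap to (m, n, ...) and flatten back. swapaxes returns a view; the final
    # reshape generally has to copy because that view is not contiguous.
    return grid.swapaxes(0, 1).reshape(A.shape)

n, m = 3, 5
a = np.array([i * np.eye(3) for i in range(1, 1 + m)])
A = np.tile(a, (n, 1, 1))
B = interleave_phases(A, m)

# Same result as the list-comprehension version from the question.
assert np.array_equal(B, np.vstack([A[i::m] for i in range(m)]))
print(B[:, 0, 0].tolist())
# [1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0]
```

Unlike the list comprehension, this version does not handle an A whose first dimension is not a multiple of m; it raises instead.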
answered Feb 06 '26 21:02 by jdehesa