 

How to efficiently unroll a matrix by value with numpy?

I have a matrix M containing integer values from 0 through N-1. I'd like to unroll this matrix into a new array A where each submatrix A[i, :, :] is 1 where M == i and 0 elsewhere.

The solution below uses a loop.

# Example Setup
import numpy as np

np.random.seed(0)
N = 5
M = np.random.randint(0, N, size=(5,5))

# Solution with Loop
A = np.zeros((N, M.shape[0], M.shape[1]), dtype=int)  # int so the slices hold 0/1
for i in range(N):
    A[i, :, :] = M == i

This yields:

M
array([[4, 0, 3, 3, 3],
       [1, 3, 2, 4, 0],
       [0, 4, 2, 1, 0],
       [1, 1, 0, 1, 4],
       [3, 0, 3, 0, 2]])

M.shape
# (5, 5)


A 
array([[[0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0]],
       ...
       [[1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0]]])

A.shape
# (5, 5, 5)
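For reference, the loop can be packaged as a function and checked on a non-square matrix. This is a minimal sketch; the name `unroll_loop` is hypothetical, and it assumes M holds integers in 0..N-1 (a value outside that range would leave its cell 0 in every slice).

```python
import numpy as np

def unroll_loop(M, N):
    """Hypothetical helper: slice i of the result is (M == i) as 0/1 integers."""
    # dtype=int so the result holds 0/1 rather than 0./1.
    A = np.zeros((N,) + M.shape, dtype=int)
    for i in range(N):
        A[i] = M == i  # boolean mask is cast to 0/1 on assignment
    return A

M = np.array([[0, 2, 1],
              [2, 2, 0]])        # 2 x 3, values in 0..2
A = unroll_loop(M, 3)
print(A.shape)                   # (3, 2, 3)
print(A.sum(axis=0))             # every cell is 1 in exactly one slice
```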

Is there a faster way, or a way to do it in a single numpy operation?

seveibar asked Apr 05 '19 22:04


2 Answers

Broadcasted comparison is your friend:

B = (M[None, :] == np.arange(N)[:, None, None]).view(np.int8)

np.array_equal(A, B)
# True

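The idea is to insert singleton axes so that the comparison broadcasts: M[None, :] has shape (1, rows, cols) and np.arange(N)[:, None, None] has shape (N, 1, 1), so the result has shape (N, rows, cols) with entry [i, j, k] equal to M[j, k] == i. The .view(np.int8) then reinterprets the one-byte booleans as 0/1 integers.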
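A small sketch on a non-square 2 x 3 matrix (values chosen for illustration) makes the intermediate shapes visible:

```python
import numpy as np

N = 3
M = np.array([[0, 2, 1],
              [2, 2, 0]])            # shape (2, 3)

left = M[None, :]                    # add a leading axis  -> (1, 2, 3)
right = np.arange(N)[:, None, None]  # one label per slice -> (3, 1, 1)
print(left.shape, right.shape)       # (1, 2, 3) (3, 1, 1)

# Size-1 axes are stretched: mask[i, j, k] = (M[j, k] == i)
mask = left == right
print(mask.shape)                    # (3, 2, 3)
B = mask.view(np.int8)               # reinterpret the 1-byte bools as 0/1 int8
```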


As pointed out by @Alex Riley in the comments, you can use np.equal.outer to avoid inserting the new axes by hand:

B = np.equal.outer(np.arange(N), M).view(np.int8)

np.array_equal(A, B)
# True
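`np.equal.outer(a, b)` compares every element of `a` with every element of `b`, so its result has shape `a.shape + b.shape`; the same line therefore works unchanged if M has three or more dimensions. The sketch below (a random 3-D M, chosen for illustration) also uses `np.shares_memory` to contrast `.view(np.int8)`, which reuses the boolean buffer, with `.astype(int)`, which allocates a new array.

```python
import numpy as np

N = 4
rng = np.random.default_rng(0)
M = rng.integers(0, N, size=(2, 3, 5))   # any shape works, here 3-D

mask = np.equal.outer(np.arange(N), M)   # shape (N,) + M.shape
print(mask.shape)                        # (4, 2, 3, 5)

B = mask.view(np.int8)                   # same bytes, reinterpreted: no copy
C = mask.astype(int)                     # new int array: a copy
print(np.shares_memory(B, mask), np.shares_memory(C, mask))  # True False

# Each element of M is "on" in exactly one slice along axis 0
print((B.sum(axis=0) == 1).all())        # True
```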
cs95 answered Nov 15 '22 07:11


You can make use of some broadcasting here:

P = np.arange(N)
# Broadcast the labels to shape (N, rows, cols); slice i is filled with i
Y = np.broadcast_to(P[:, None, None], (N,) + M.shape)
T = np.equal(M, Y).astype(int)
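`np.broadcast_to` returns a read-only view rather than a copy: the repeated axes get a stride of 0, so an (N, rows, cols) array of labels costs no extra memory. A quick sketch on a non-square matrix (shapes chosen for illustration), with the target shape built as `(N,) + M.shape` so it does not depend on N matching the row count:

```python
import numpy as np

N = 3
M = np.array([[0, 2, 1],
              [2, 2, 0]])                    # shape (2, 3)
P = np.arange(N)

# Target shape spelled out as (N, rows, cols)
Y = np.broadcast_to(P[:, None, None], (N,) + M.shape)
print(Y.shape, Y.strides[1:])                # (3, 2, 3) (0, 0)
print(Y.flags.writeable)                     # False: it is a read-only view

T = np.equal(M, Y).astype(int)               # T[i] marks where M == i
```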

Alternative using indices:

# Component 0 of np.indices gives X[i, j, k] == i, with shape (N, rows, cols)
X = np.indices((N,) + M.shape)[0]
Z = np.equal(M, X).astype(int)
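`np.indices(shape)` returns one integer grid per axis, stacked into an array of shape `(len(shape),) + shape`; component 0 holds each element's index along the first axis. Requesting the shape `(N,) + M.shape` therefore gives a grid whose slice i is filled with i, which can be compared with M directly. A sketch under those assumptions (example values chosen for illustration):

```python
import numpy as np

N = 3
M = np.array([[0, 2, 1],
              [2, 2, 0]])                 # shape (2, 3)

grid = np.indices((N,) + M.shape)         # one grid per axis of (3, 2, 3)
print(grid.shape)                         # (3, 3, 2, 3)
labels = grid[0]                          # labels[i, j, k] == i
Z = np.equal(M, labels).astype(int)       # Z[i] marks where M == i
print(Z.shape)                            # (3, 2, 3)
```

Unlike `np.broadcast_to`, `np.indices` materializes every grid in memory, so this variant uses more memory than the broadcasting approaches.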
user3483203 answered Nov 15 '22 08:11