
In numpy, how to efficiently list all fixed-size submatrices?

I have an arbitrary NxM matrix, for example:

1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4

I want to get a list of all 3x3 submatrices in this matrix:

1 2 3       2 3 4               0 1 2
7 8 9   ;   8 9 0   ;  ...  ;   6 7 8
3 4 5       4 5 6               2 3 4

I can do this with two nested loops:

import numpy as np

rows, cols = input_matrix.shape
patches = []
# A 3x3 window fits at rows - 3 + 1 row offsets and cols - 3 + 1 column offsets.
for row in range(rows - 2):
    for col in range(cols - 2):
        patches.append(input_matrix[row:row+3, col:col+3])

But for a large input matrix, this is slow. Is there a way to do this faster with numpy?

I've looked at np.split, but that gives me non-overlapping sub-matrices, whereas I want all possible submatrices, regardless of overlap.

cberzan asked Oct 16 '13 21:10



1 Answer

You want a windowed view:

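import numpy as np
from numpy.lib.stride_tricks import as_strided

arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)
# One window per valid top-left corner, followed by the window's own shape: (2, 4, 3, 3).
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
# Repeating the strides makes the window axes step through arr exactly like the outer axes.
arr_view = as_strided(arr, view_shape, arr.strides * 2)
# Flatten the grid of windows into a list of 8 submatrices.
arr_view = arr_view.reshape((-1,) + sub_shape)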

>>> arr_view
array([[[1, 2, 3],
        [7, 8, 9],
        [3, 4, 5]],

       [[2, 3, 4],
        [8, 9, 0],
        [4, 5, 6]],

       ...

       [[9, 0, 1],
        [5, 6, 7],
        [1, 2, 3]],

       [[0, 1, 2],
        [6, 7, 8],
        [2, 3, 4]]])
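As a quick sanity check, the sketch below compares the strided windows against the plain nested-loop version for the same 4x6 example. The helper names `windows_strided` and `windows_loop` are hypothetical, introduced only for this illustration, and the code assumes a 2-D input array.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def windows_strided(arr, sub_shape):
    """Hypothetical helper: all overlapping windows of a 2-D array as a 4-D view."""
    view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
    return as_strided(arr, view_shape, arr.strides * 2)


def windows_loop(arr, sub_shape):
    """Hypothetical helper: the same windows, collected with two explicit loops."""
    h, w = sub_shape
    rows, cols = arr.shape
    return np.array([arr[r:r + h, c:c + w]
                     for r in range(rows - h + 1)
                     for c in range(cols - w + 1)])


arr = np.arange(1, 25).reshape(4, 6) % 10
strided = windows_strided(arr, (3, 3)).reshape(-1, 3, 3)
looped = windows_loop(arr, (3, 3))

# Both approaches should produce the same 8 submatrices in the same order.
print(strided.shape, np.array_equal(strided, looped))
```

Both versions visit the windows in row-major order of their top-left corner, which is why the element-wise comparison is meaningful.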

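The good part of doing it like this is that the strided view does not copy any data; it simply accesses the data of your original array in a different way. For large arrays this can result in tremendous memory savings. Note that the final reshape cannot be expressed as a view of the overlapping windows, so NumPy makes a copy at that step; keep the 4-D arr_view (indexed by window row and window column) if you want to avoid the copy. Also be careful when writing to an as_strided view, since the windows overlap and share memory with the original array.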
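On newer NumPy versions (1.20 and later), `numpy.lib.stride_tricks.sliding_window_view` builds the same kind of windowed view without computing shapes and strides by hand. It is not part of the original answer; the sketch below only illustrates, for the same example array, which step shares memory with the input and which step copies.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

arr = np.arange(1, 25).reshape(4, 6) % 10

# Read-only 4-D view: one 3x3 window per valid top-left corner.
windows = sliding_window_view(arr, (3, 3))

# Flattening the window grid into a list forces NumPy to copy the data.
flat = windows.reshape(-1, 3, 3)

print(windows.shape, np.shares_memory(arr, windows))
print(flat.shape, np.shares_memory(arr, flat))
```

If memory matters, iterating over `windows[i, j]` keeps everything as a view; the flattened `flat` array is more convenient but holds a separate copy of every window.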

Jaime answered Nov 08 '22 05:11