
How to determine a numpy-array reshape strategy

For a Python project I often find myself reshaping and rearranging n-dimensional numpy arrays. However, I have a hard time working out how to approach the problem, visualizing the outcome of the reshaping methods, and knowing whether my solution is efficient.

At the moment, when confronted with such a problem, my strategy is to start ipython, load some sample data, and go by trial and error until I find a combination of transpose(), reshape() and swapaxes() calls that gets the desired result. It gets the job done, but without a real understanding of what is going on, and it often produces code that is hard to maintain.

So, my question is about finding a strategy. How do you approach such a problem? How do you visualize an ndarray in your head when you have to shape it in the desired format? How do you come to the right actions?

To make answering a bit more concrete, an example to play with:

Assume you want to reshape the following 3d-array

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])

to a 2d-array where the first columns from the 3rd dimension are placed first, the 2nd columns second, etc.

The result should look like this:

array([[ 0,  9, 18,  3, 12, 21,  6, 15, 24],
       [ 1, 10, 19,  4, 13, 22,  7, 16, 25],
       [ 2, 11, 20,  5, 14, 23,  8, 17, 26]])

PS: any reading material on the subject would also be great!

asked May 28 '15 17:05 by joepjp


1 Answer

I regularly play about with shapes in ipython. However, to make things clearer, I start with an array whose dimensions are all distinct.

import numpy as np

arr = np.arange(3*4*5).reshape(3,4,5)

That way, it's easier to identify how the axes get shifted, for example:

In [25]: arr.shape
Out[25]: (3, 4, 5)

In [26]: arr.T.shape
Out[26]: (5, 4, 3)

In [31]: arr.T.reshape(5,-1)
Out[31]: 
array([[ 0, 20, 40,  5, 25, 45, 10, 30, 50, 15, 35, 55],
       [ 1, 21, 41,  6, 26, 46, 11, 31, 51, 16, 36, 56],
       [ 2, 22, 42,  7, 27, 47, 12, 32, 52, 17, 37, 57],
       [ 3, 23, 43,  8, 28, 48, 13, 33, 53, 18, 38, 58],
       [ 4, 24, 44,  9, 29, 49, 14, 34, 54, 19, 39, 59]])

whereas a different transpose (one that does not switch the order of the 3 and 4 axes) gives:

In [38]: np.transpose(arr,[2,0,1]).shape
Out[38]: (5, 3, 4)

In [39]: np.transpose(arr,[2,0,1]).reshape(5,-1)
Out[39]: 
array([[ 0,  5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55],
       [ 1,  6, 11, 16, 21, 26, 31, 36, 41, 46, 51, 56],
       [ 2,  7, 12, 17, 22, 27, 32, 37, 42, 47, 52, 57],
       [ 3,  8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58],
       [ 4,  9, 14, 19, 24, 29, 34, 39, 44, 49, 54, 59]])
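
Applied to the 3x3x3 array from the question, the first pattern above (reverse all the axes, then merge the trailing two) appears to give the requested layout. This is a minimal sketch; the names q, result and explicit are only illustrative.

```python
import numpy as np

# The 3x3x3 array from the question.
q = np.arange(27).reshape(3, 3, 3)

# Reverse the axes, then merge the last two into one.
result = q.T.reshape(3, -1)

# The same operation with the axis order spelled out explicitly.
explicit = np.transpose(q, (2, 1, 0)).reshape(3, -1)

print(result)
```

Because all three dimensions are equal here, a wrong axis order would still produce a (3, 9) array, so the values have to be checked by eye. That is the reason for experimenting with a (3, 4, 5) array first.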

I also like to use 'oddly' shaped arrays like this when developing functions. That way, if I do mess up some transpose or broadcasting, dimension errors will jump out at me. Experience tells me that once I get the dimensions right, the values will also be correct. Or at least the class of errors that affect values is quite different from those that affect dimensions.
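
To illustrate the point, here is a small sketch with a deliberately buggy helper (project is a made-up name, not a numpy function). With square inputs the stray transpose goes unnoticed; with distinct dimensions it raises immediately.

```python
import numpy as np

def project(a, b):
    # Hypothetical helper with a deliberate bug: the transpose should not be there.
    return a @ b.T

# Square inputs: the buggy call runs and silently returns the wrong values.
a_sq = np.arange(9.0).reshape(3, 3)
b_sq = np.arange(9.0, 18.0).reshape(3, 3)
silent = project(a_sq, b_sq)
wrong_but_quiet = not np.array_equal(silent, a_sq @ b_sq)

# Distinct dimensions: the same bug fails on the shape mismatch.
a = np.arange(12.0).reshape(3, 4)
b = np.arange(20.0).reshape(4, 5)
try:
    project(a, b)
    raised = False
except ValueError as err:
    raised = True
    print("shape error:", err)
```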

I also liberally sprinkle development code with statements like print(arr.shape), or even assertions such as assert x.shape == y.shape.
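
A rough sketch of what that can look like in practice. The function name columns_first is hypothetical, and the body is just the transpose-and-reshape pattern from earlier in this answer.

```python
import numpy as np

def columns_first(arr):
    # Hypothetical helper: reverse the axes and flatten all but the first.
    M, N, L = arr.shape            # fails early if arr is not 3-d
    out = arr.T.reshape(L, -1)
    print("arr:", arr.shape, "-> out:", out.shape)   # cheap trace while developing
    assert out.shape == (L, M * N), out.shape
    return out

out = columns_first(np.arange(3 * 4 * 5).reshape(3, 4, 5))
```

The print can be dropped once the function is settled; the assert is cheap enough that it can often stay.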

It also helps to label dimensions:

M, N, L = 3, 4, 5
np.empty((M,N,L))

or, as in einsum:

np.einsum('ijk,kj->i', A, B) # if A is (M,N,L), B must be (L,N)
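
As a quick check of that labelling, the sketch below builds A and B with the labelled sizes and compares the einsum result with the same contraction written with plain broadcasting. The variable names via_einsum and via_broadcast are only illustrative.

```python
import numpy as np

M, N, L = 3, 4, 5
A = np.arange(M * N * L, dtype=float).reshape(M, N, L)   # axes i, j, k
B = np.arange(L * N, dtype=float).reshape(L, N)          # axes k, j

# Sum over j and k, keep i.
via_einsum = np.einsum('ijk,kj->i', A, B)

# The same contraction with broadcasting: B.T has shape (N, L),
# which lines up with the last two axes of A.
via_broadcast = (A * B.T).sum(axis=(1, 2))

print(via_einsum.shape)
```

Passing a B of shape (N, L) instead makes einsum raise, because the sizes bound to j and k no longer agree between the two operands.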

https://stackoverflow.com/a/29903842/901925 is an example of trying to understand and explain rollaxis.

Another strategy is to look at the Python code of numpy functions. Often they accept axis arguments. It's instructive to see how they use those. Sometimes that particular axis is rotated to the front, or to the end. Sometimes an nd array is reshaped into a 2d array, collapsing all axes except one down to one. Others achieve generality by constructing and manipulating an indexing tuple. More advanced functions play with the strides as well as the shape.
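
Two of those patterns can be sketched as follows. Both functions are hypothetical illustrations written for this answer, not copies of numpy internals: demean_along moves the chosen axis to the end and collapses the rest into one axis, and take_first builds an indexing tuple.

```python
import numpy as np

def demean_along(arr, axis):
    # Hypothetical example: subtract the mean along one axis by working
    # on a 2-d array with the chosen axis last.
    moved = np.moveaxis(arr, axis, -1)
    flat = moved.reshape(-1, moved.shape[-1])
    flat = flat - flat.mean(axis=1, keepdims=True)
    # Restore the original shape and put the axis back where it was.
    return np.moveaxis(flat.reshape(moved.shape), -1, axis)

def take_first(arr, axis):
    # Build an indexing tuple: slice(None) everywhere, 0 on the chosen axis.
    index = [slice(None)] * arr.ndim
    index[axis] = 0
    return arr[tuple(index)]

arr = np.arange(3 * 4 * 5, dtype=float).reshape(3, 4, 5)
centered = demean_along(arr, axis=1)
first = take_first(arr, axis=1)
```

Both work for any number of dimensions, which is the point of writing them this way instead of hard-coding a transpose.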

Whether a dimension should be first or last is usually an optimization issue, and may involve tradeoffs between ease of use (broadcasting, indexing) and speed. Just keep in mind that for "C" order, the last dimension forms contiguous blocks.
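
One way to see the layout is to look at strides and at whether a reshape can reuse the original buffer. This is a small sketch; the dtype is fixed to int64 so the stride values are predictable, and the variable names are only illustrative.

```python
import numpy as np

arr = np.arange(3 * 4 * 5, dtype=np.int64).reshape(3, 4, 5)

print(arr.strides)      # bytes to step along each axis; smallest on the last axis
print(arr.T.strides)    # transposing only reverses the strides, no data moves

# Merging axes of the C-ordered array can reuse the same buffer...
merged = arr.reshape(3, -1)
shares_plain = np.shares_memory(arr, merged)

# ...but after the transpose the data has to be copied into the new layout.
merged_t = arr.T.reshape(5, -1)
shares_transposed = np.shares_memory(arr, merged_t)
```

So a transpose on its own is cheap, while a transpose followed by a reshape may cost a full copy. For large arrays that is worth knowing when choosing which axis goes last.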

answered Sep 27 '22 16:09 by hpaulj