Easy way to collapse trailing dimensions of numpy array?

In Matlab, I can do the following:

X = randn(25,25,25);
size(X(:,:))

ans = 
    25   625

I often find myself wanting to quickly collapse the trailing dimensions of an array, and do not know how to do this in numpy.

I know I can do this:

In [22]: x = np.random.randn(25,25,25)
In [23]: x = x.reshape(x.shape[:-2] + (-1,))
In [24]: x.shape
Out[24]: (25, 625)

but x.reshape(x.shape[:-2] + (-1,)) is a lot less concise (and requires more information about x) than simply doing x(:,:).

I've obviously tried the analogous numpy indexing, but that does not work as desired:

In [25]: x = np.random.randn(25,25,25)
In [26]: x[:,:].shape
Out[26]: (25, 25, 25)

Any hints on how to collapse the trailing dimensions of an array in a concise manner?

Edit: note that I'm after the resulting array itself, not just its shape. I merely use size() and x.shape in the above examples to indicate what the array is like.

asked Jun 11 '15 11:06 by EelkeSpaak


1 Answer

What is supposed to happen with a 4d or higher-dimensional array?

octave:7> x=randn(25,25,25,25);
octave:8> size(x(:,:))
ans =
      25   15625

Your (:,:) reduces it to 2 dimensions, combining all the trailing ones. MATLAB automatically adds and collapses dimensions at the trailing end.

In [605]: x=np.ones((25,25,25,25))

In [606]: x.reshape(x.shape[0],-1).shape  # like Joe's
Out[606]: (25, 15625)

In [607]: x.reshape(x.shape[:-2]+(-1,)).shape
Out[607]: (25, 25, 625)

Your reshape example does something different from MATLAB: it collapses only the last 2 dimensions, which happens to give the same result for a 3d array. Collapsing down to 2 dimensions, as MATLAB does, is actually the simpler expression.

The MATLAB version is concise simply because your needs match its assumptions. The numpy equivalent isn't quite so concise, but it gives you more control.
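If the MATLAB-style collapse comes up often, it can be wrapped in a small function. The sketch below is only an illustration: `collapse_trailing` and its `ndim` and `order` parameters are hypothetical names, not part of numpy. Note that MATLAB stores arrays column-major, so the elements only line up with MATLAB's result when `order="F"` is passed; the default `order="C"` gives the same shape with numpy's usual row-major layout.

```python
import numpy as np

def collapse_trailing(a, ndim=2, order="C"):
    """Reshape `a` to `ndim` dimensions by merging all trailing axes.

    Hypothetical helper, not a numpy function. order="F" mimics MATLAB's
    column-major element ordering; order="C" is numpy's default.
    """
    a = np.asarray(a)
    if not 1 <= ndim <= a.ndim:
        raise ValueError("ndim must be between 1 and a.ndim")
    # Keep the first ndim-1 axes and let -1 absorb everything after them.
    return a.reshape(a.shape[:ndim - 1] + (-1,), order=order)

x = np.ones((25, 25, 25, 25))
print(collapse_trailing(x).shape)          # (25, 15625)
print(collapse_trailing(x, ndim=3).shape)  # (25, 25, 625)
```

With `ndim=2` this is the same as `x.reshape(x.shape[0], -1)`; the function only saves typing when the target number of dimensions varies.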

For example to keep the last dimension, or combine dimensions 2 by 2:

In [608]: x.reshape(-1,x.shape[-1]).shape
Out[608]: (15625, 25)
In [610]: x.reshape(-1,np.prod(x.shape[-2:])).shape
Out[610]: (625, 625)
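Since the question asks for the resulting array and not just its shape, it may help to see where the elements end up. The following is a minimal sketch on a small array with distinct values (the sizes 2, 3, 4, 5 are arbitrary choices for illustration), using numpy's default C ordering:

```python
import numpy as np

# Small 4-d array with distinct values so positions can be traced.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
n0, n1, n2, n3 = x.shape

# Merge axes (0, 1) and axes (2, 3), i.e. combine dimensions 2 by 2.
y = x.reshape(-1, n2 * n3)
print(y.shape)  # (6, 20)

# With the default C order, x[i, j, k, l] lands at y[i*n1 + j, k*n3 + l].
i, j, k, l = 1, 2, 3, 4
assert y[i * n1 + j, k * n3 + l] == x[i, j, k, l]

# x is contiguous here, so the reshape is a view rather than a copy.
print(np.shares_memory(x, y))  # True
```

Because the result is a view in this case, writing to `y` also changes `x`; for a non-contiguous input, `reshape` may have to return a copy instead.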

What's the equivalent MATLAB?

octave:24> size(reshape(x,[],size(x)(end)))
ans =
15625      25
octave:31> size(reshape(x,[],prod(size(x)(3:end))))
ans =
   625   625

answered Sep 22 '22 01:09 by hpaulj