
Numpy: Reshape array along a specified axis

I have the following array:

x = np.arange(24).reshape((2,3,2,2))
array([[[[ 0,  1],
     [ 2,  3]],

    [[ 4,  5],
     [ 6,  7]],

    [[ 8,  9],
     [10, 11]]],


   [[[12, 13],
     [14, 15]],

    [[16, 17],
     [18, 19]],

    [[20, 21],
     [22, 23]]]])

I would like to reshape it to a (3,4,2) array like below:

array([[[ 0,  1],
    [ 2,  3],
    [12, 13],
    [14, 15]],

   [[ 4,  5],
    [ 6,  7],
    [16, 17],
    [18, 19]],

   [[ 8,  9],
    [10, 11],
    [20, 21],
    [22, 23]]])

I tried a plain reshape, but it gave me the following, which is not what I want:

array([[[ 0,  1],
    [ 2,  3],
    [ 4,  5],
    [ 6,  7]],

   [[ 8,  9],
    [10, 11],
    [12, 13],
    [14, 15]],

   [[16, 17],
    [18, 19],
    [20, 21],
    [22, 23]]])

Can someone please help?

Allen asked Aug 13 '16 05:08



3 Answers

Use transpose and then reshape like so -

shp = x.shape
out = x.transpose(1,0,2,3).reshape(shp[1],-1,shp[-1])
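
To see why this works: the transpose moves the length-3 axis to the front, so the two axes being merged (the original axis 0 and axis 2) end up next to each other, and the reshape then only has to collapse adjacent axes. A minimal check, assuming NumPy is imported as np; the names `moved` and `expected` are only for illustration:

```python
import numpy as np

x = np.arange(24).reshape((2, 3, 2, 2))

# Move the length-3 axis to the front: shape (2, 3, 2, 2) -> (3, 2, 2, 2).
moved = x.transpose(1, 0, 2, 3)

# Merge the two middle axes (old axis 0 and old axis 2) into one of length 4.
out = moved.reshape(x.shape[1], -1, x.shape[-1])

# Target layout from the question, built block by block for comparison.
expected = np.array([np.vstack((x[0, i], x[1, i])) for i in range(3)])

print(out.shape)                      # (3, 4, 2)
print(np.array_equal(out, expected))  # True
```

Because the transposed view is no longer contiguous in memory, the reshape has to copy the data here rather than return a view.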
Divakar answered Sep 19 '22 19:09


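import numpy as np
x = np.arange(24).reshape((2,3,2,2))
# zip() returns an iterator on Python 3, so turn it into a list before stacking
y = np.dstack(list(zip(x)))[0]
print(y)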

result:

[[[ 0  1]
  [ 2  3]
  [12 13]
  [14 15]]

 [[ 4  5]
  [ 6  7]
  [16 17]
  [18 19]]

 [[ 8  9]
  [10 11]
  [20 21]
  [22 23]]]
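
As a rough sketch of what happens step by step (the names `pieces` and `stacked` are only for illustration): zip(x) wraps each x[i] in a one-element tuple, each tuple becomes an array with an extra leading axis, and dstack joins those along axis 2.

```python
import numpy as np

x = np.arange(24).reshape((2, 3, 2, 2))

# zip(x) yields one-element tuples: (x[0],) and (x[1],).
pieces = [np.asarray(t) for t in zip(x)]
print([p.shape for p in pieces])   # [(1, 3, 2, 2), (1, 3, 2, 2)]

# dstack joins along axis 2, giving shape (1, 3, 4, 2).
stacked = np.dstack(pieces)
print(stacked.shape)               # (1, 3, 4, 2)

# Indexing with [0] drops the leading length-1 axis.
y = stacked[0]
print(y.shape)                     # (3, 4, 2)
```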
Julien answered Sep 16 '22 19:09


You can also use concatenate, like so:

out = np.concatenate(x, axis=1)

I will note, though, that since you mentioned this is for performance, this doesn't seem to be faster than Divakar's suggestion:

shp = x.shape
out = x.transpose(1,0,2,3).reshape(shp[1],-1,shp[-1])

If anyone does a benchmark or finds something faster, I would love to know.
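
One way to compare the two approaches is a small timeit script. This is only a sketch: the array `big` and its sizes are made up for illustration, the helper names are hypothetical, and the timings will vary with the machine, the NumPy build, and the array shape.

```python
import timeit
import numpy as np

# Hypothetical larger input with the same 4-D layout as x; sizes are arbitrary.
big = np.arange(2 * 300 * 20 * 20).reshape((2, 300, 20, 20))

def via_concatenate(a):
    # Joins a[0], a[1], ... along axis 1 of each sub-array.
    return np.concatenate(a, axis=1)

def via_transpose_reshape(a):
    shp = a.shape
    return a.transpose(1, 0, 2, 3).reshape(shp[1], -1, shp[-1])

# Confirm both approaches agree before comparing their speed.
same = np.array_equal(via_concatenate(big), via_transpose_reshape(big))
print(same)

for fn in (via_concatenate, via_transpose_reshape):
    # Best of 5 repeats, 100 calls each.
    best = min(timeit.repeat(lambda: fn(big), number=100, repeat=5))
    print(f"{fn.__name__}: {best / 100 * 1e6:.1f} us per call")
```

Both versions copy the data, so any difference is likely to come from how the copy is carried out rather than from how much is copied.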

Supamee answered Sep 18 '22 19:09