Merge axis in numpy array

I want to convert a numpy array of shape (X, Y, Z) into an array of shape (X*Z, Y).

Code (slow):

    import numpy as np

    def rearrange(data):
        samples, channels, t_insts = data.shape
        append_data = np.empty([0, channels])
        # Append data[sample, :, t_inst] as a new row for every (sample, t_inst) pair
        for sample in range(samples):
            for t_inst in range(t_insts):
                channel_data = data[sample, :, t_inst]
                append_data = np.vstack((append_data, channel_data))
        return append_data
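Much of the cost in this loop comes from np.vstack building a brand-new array on every iteration, so all rows accumulated so far are copied again each time (quadratic work overall). Even before vectorizing, collecting the rows in a list and stacking once avoids that. A minimal sketch, where rearrange_list and the (2, 3, 4) test array are hypothetical:

```python
import numpy as np


def rearrange_list(data):
    # Hypothetical variant of the loop above: collect the rows in a
    # Python list first, then stack them once at the end
    samples, channels, t_insts = data.shape
    rows = [data[sample, :, t_inst]
            for sample in range(samples)
            for t_inst in range(t_insts)]
    if not rows:
        # Match the original loop, which yields an empty (0, channels) array
        return np.empty((0, channels))
    # A single np.vstack call copies each row exactly once
    return np.vstack(rows)


data = np.arange(24).reshape(2, 3, 4)  # arbitrary (X, Y, Z) = (2, 3, 4) test array
out = rearrange_list(data)
print(out.shape)  # (8, 3)
```

Row sample * t_insts + t_inst of the result holds data[sample, :, t_inst], the same sample-major order as the original loop.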

I am looking for a faster, vectorized approach if possible.

asked Jan 07 '23 19:01 by Abhishek Bhatia


1 Answer

You can use np.transpose to swap the last two axes (the rows and columns of each 2-D slice) and then reshape:

    data.transpose(0, 2, 1).reshape(-1, data.shape[1])
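To see why this works: transpose(0, 2, 1) turns the (X, Y, Z) array into an (X, Z, Y) array, so the last axis holds the Y channel values for one (sample, time instant) pair; reshape(-1, data.shape[1]) then merges the first two axes into X*Z rows, in the same sample-major order the loop uses. A small check against an explicit loop, assuming an arbitrary (2, 3, 4) test array:

```python
import numpy as np

data = np.arange(24).reshape(2, 3, 4)      # shape (X, Y, Z) = (2, 3, 4)

swapped = data.transpose(0, 2, 1)          # shape (X, Z, Y) = (2, 4, 3)
out = swapped.reshape(-1, data.shape[1])   # shape (X*Z, Y) = (8, 3)

# The same rows built one at a time, in the loop's order:
# sample-major, then time instant
expected = np.array([data[s, :, t]
                     for s in range(data.shape[0])
                     for t in range(data.shape[2])])

print(swapped.shape, out.shape)        # (2, 4, 3) (8, 3)
print(np.array_equal(out, expected))   # True
```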

Alternatively, use np.swapaxes to swap axes 1 and 2 and then reshape:

    data.swapaxes(1, 2).reshape(-1, data.shape[1])
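For a 3-D array, swapaxes(1, 2) is the same permutation as transpose(0, 2, 1), so both one-liners give identical results. Both calls return views, but the permuted view is no longer C-contiguous, so the subsequent reshape has to copy the data. A quick check, again with a hypothetical (2, 3, 4) test array:

```python
import numpy as np

data = np.arange(24).reshape(2, 3, 4)

a = data.transpose(0, 2, 1).reshape(-1, data.shape[1])
b = data.swapaxes(1, 2).reshape(-1, data.shape[1])
print(np.array_equal(a, b))            # True

view = data.swapaxes(1, 2)
print(np.shares_memory(view, data))    # True: swapaxes returns a view
print(view.flags['C_CONTIGUOUS'])      # False: the axes are permuted in memory
print(np.shares_memory(b, data))       # False: reshape had to make a copy
```

The copy happens once, inside compiled code, which is why this is far faster than growing an array with np.vstack in a Python loop.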
answered Jan 14 '23 22:01 by Divakar