
How to take the transpose of each matrix in a batch in PyTorch?

Tags:

pytorch

Say I have a batch of four 5x3 matrices, so the tensor's dimensions are 4x5x3. How do I take the transpose of each matrix in the batch, converting the tensor to 4x3x5?
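To pin down what "transpose each matrix" means here, the following is a minimal reference sketch (not a recommended solution) that transposes each 5x3 matrix separately with `Tensor.t()` and stacks the results. It loops in Python, so it is only meant as a correctness baseline to compare faster approaches against.

```python
import torch

# A batch of four 5x3 matrices, as described in the question.
x = torch.randn(4, 5, 3)

# Iterating over x yields the 5x3 matrices along dim 0.
# Transpose each one and stack them back into a batch.
reference = torch.stack([m.t() for m in x])

print(reference.shape)  # torch.Size([4, 3, 5])
```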

asked Jan 26 '23 22:01 by SantoshGupta7



1 Answer

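I will add some benchmarks here to compare performance, using the same tensor proposed in the OP's answer. Note that the session below swaps dimensions 0 and 1 (the batch and row dimensions); for the question's 4x5x3 case, the equivalent calls are torch.transpose(x, 1, 2) and x.permute(0, 2, 1).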

In[2]: import torch
In[3]: x = torch.randn(2, 3, 5)
In[4]: x.size()
Out[4]: torch.Size([2, 3, 5])
In[5]: %timeit x.permute(1, 0, 2)
1.03 µs ± 41.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In[6]: %timeit torch.transpose(x, 0, 1)
892 ns ± 9.61 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In[7]: torch.transpose(x, 0, 1).equal(x.permute(1, 0, 2))
Out[7]: True

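In this benchmark, torch.transpose was slightly faster than permute (892 ns vs. 1.03 µs per call), so I would use it when you only need to swap two dimensions; permute is only needed for more general reorderings.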
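For the question's layout (batch dimension first), a minimal sketch of the batch transpose follows, assuming a 4x5x3 input like the one in the question. It swaps only the last two dimensions, checks that `transpose` and `permute` agree with a per-matrix `.t()`, and shows that the result is a view of the original storage rather than a copy.

```python
import torch

# Example input matching the question: a batch of four 5x3 matrices.
x = torch.randn(4, 5, 3)

# Swap only the last two dimensions; the batch dimension (0) stays in place.
y = torch.transpose(x, 1, 2)   # same as x.transpose(1, 2)
z = x.permute(0, 2, 1)         # same result via permute
w = x.transpose(-2, -1)        # negative dims also work with extra leading batch dims

print(y.shape)  # torch.Size([4, 3, 5])

# All three agree, and each matrix is the transpose of the corresponding input matrix.
assert torch.equal(y, z) and torch.equal(y, w)
assert all(torch.equal(y[i], x[i].t()) for i in range(x.shape[0]))

# transpose/permute return views: the data is not copied, only the strides change.
print(y.data_ptr() == x.data_ptr())  # True
print(y.is_contiguous())             # False
```

Because the result is a non-contiguous view, an operation that requires contiguous memory (such as `.view()` across the swapped dimensions) may need `.contiguous()` first, which does copy the data. Recent PyTorch releases also provide `x.mT` as a shorthand for transposing the last two dimensions.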

answered Feb 16 '23 21:02 by ndrwnaguib
