
PyTorch - shape of nn.Linear weights

Yesterday I came across this question and for the first time noticed that the weights of the linear layer nn.Linear need to be transposed before applying matmul.

Code for applying the weights:

# weight is stored with shape (out_features, in_features), hence the transpose
output = input.matmul(weight.t())


What is the reason for this?

Why are the weights not stored in the transposed shape to begin with, so they don't need to be transposed every time the layer is applied?
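For illustration, here is a minimal sketch (the layer sizes are arbitrary) showing that nn.Linear stores its weight as (out_features, in_features), which is why the transpose is needed to reproduce the forward pass by hand:

import torch
import torch.nn as nn

# nn.Linear(in_features=3, out_features=5) stores weight as (out_features, in_features)
layer = nn.Linear(3, 5)
print(layer.weight.shape)  # torch.Size([5, 3])

x = torch.randn(2, 3)  # batch of 2 samples, 3 features each
# reproduce the layer's forward pass manually: x @ W^T + b
manual = x.matmul(layer.weight.t()) + layer.bias
print(torch.allclose(manual, layer(x)))  # True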

asked Oct 16 '22 by MBT


1 Answer

I found an answer here: Efficient forward pass in nn.Linear #2159

It seems there is no deep design reason behind this. However, the transpose operation does not appear to slow down the computation.

According to the issue mentioned above, the transpose operation is (almost) free in terms of computation during the forward pass, while leaving it out would actually make the backward pass less efficient with the current implementation.
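To see why the transpose is cheap in the forward pass: t() returns a view that shares storage with the original tensor, so no data is copied; only the strides change, and the underlying matmul routine can consume that layout directly. A small sketch (the shapes are arbitrary):

import torch

w = torch.randn(5, 3)
wt = w.t()  # a view: no data is copied
print(wt.data_ptr() == w.data_ptr())  # True: same underlying memory
print(w.stride(), wt.stride())  # (3, 1) (1, 3): only the strides differ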

The last post in that issue sums it up quite nicely:

It's historical weight layout, changing it is backward-incompatible. Unless there is some BIG benefit in terms of speed or convenience, we wont break userland.

https://github.com/pytorch/pytorch/issues/2159#issuecomment-390068272

answered Oct 21 '22 by MBT