
Does the PyTorch Linear layer now automatically reshape the input?

I remember that in the past, nn.Linear only accepted 2D tensors.

But today I discovered that nn.Linear now accepts 3D tensors, or even tensors with an arbitrary number of dimensions.

import torch
from torch import nn

X = torch.randn((20,20,20,20,10))
linear_layer = nn.Linear(10,5)
output = linear_layer(X)
print(output.shape)
# prints: torch.Size([20, 20, 20, 20, 5])

When I checked the PyTorch documentation, it does say that the layer now takes:

Input: (N, *, H_in), where * means any number of additional dimensions and H_in = in_features

So it seems to me that PyTorch's nn.Linear now reshapes the input with x.view(-1, input_dim) automatically.

But I cannot find any x.shape or x.view in the source code:

class Linear(Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

Can anyone confirm this?

Raven Cheuk asked Feb 28 '26 11:02


1 Answer

torch.nn.Linear uses the torch.nn.functional.linear function under the hood; that is where the operations take place (see documentation).

It looks like this (removed docstrings and decorators for brevity):

def linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

The first case is addmm, which computes beta*mat + alpha*(mat1 @ mat2) and is supposedly faster (see here for example).
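As a minimal sketch of that first branch, the snippet below checks that the fused call matches the two-step version for a 2D input. The tensor names and sizes are made up for illustration; only torch.addmm and Tensor.matmul come from the code above.

```python
import torch

torch.manual_seed(0)

x = torch.randn(4, 10)        # 2D input: (batch, in_features)
weight = torch.randn(5, 10)   # same layout as nn.Linear.weight: (out_features, in_features)
bias = torch.randn(5)

# Fused: bias + x @ weight.T in a single call (beta and alpha default to 1).
fused = torch.addmm(bias, x, weight.t())

# Unfused: the same computation done in two steps, as in the else branch.
unfused = x.matmul(weight.t()) + bias

print(fused.shape)
print(torch.allclose(fused, unfused, atol=1e-5))
```

Both routes should agree up to floating-point rounding, so the choice between them is about speed and not about the result.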

The second operation is matmul, and as one can read in its docs, it performs different operations depending on the shapes of the tensors provided (five cases, which I am not going to copy here).
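For the case that matters here, an input of any rank multiplied by a 2D weight, a rough illustration follows. The shapes are arbitrary examples I picked, and the weight mimics what nn.Linear(10, 5) would hold.

```python
import torch

weight = torch.randn(5, 10)  # (out_features, in_features), as in nn.Linear(10, 5)

# Hypothetical input shapes; only the last dimension has to equal in_features.
input_shapes = [(10,), (3, 10), (2, 3, 10), (2, 3, 4, 6, 10)]

results = {}
for shape in input_shapes:
    x = torch.randn(*shape)
    out = x.matmul(weight.t())      # (..., 10) @ (10, 5) -> (..., 5)
    results[shape] = tuple(out.shape)
    print(shape, "->", results[shape])
```

In each case only the last dimension changes from 10 to 5; every leading dimension is carried through untouched.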

In summary, matmul preserves every dimension between the first (batch) dimension and the last (features) dimension. No view() is used anywhere in this code, and certainly not x.view(-1, input_dim); check the code below:

import torch

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)

print(torch.matmul(tensor1, tensor2).shape)
print(torch.matmul(tensor1, tensor2).view(-1, tensor1.shape[1]).shape)

which gives:

torch.Size([10, 3, 5]) # preserves input's 3
torch.Size([50, 3]) # destroys the batch even
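To tie this back to the question, here is a small sketch comparing nn.Linear applied directly to a 4D input against a manual flatten, apply, restore route. The shapes are illustrative assumptions, and the comparison says nothing about how the backend organizes the computation internally; it only checks the values and shapes you get back.

```python
import torch
from torch import nn

torch.manual_seed(0)

layer = nn.Linear(10, 5)
x = torch.randn(2, 3, 4, 10)

# Direct route: the layer is applied along the last dimension only.
direct = layer(x)                              # shape (2, 3, 4, 5)

# Manual route: flatten the leading dims, apply the layer, restore them.
flat = layer(x.reshape(-1, 10))                # shape (24, 5)
restored = flat.reshape(*x.shape[:-1], 5)      # shape (2, 3, 4, 5)

print(direct.shape, flat.shape, restored.shape)
print(torch.allclose(direct, restored, atol=1e-6))
```

So the numbers are the same as what flattening would produce, but the layer hands back the leading dimensions intact, which x.view(-1, input_dim) on its own would not.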
Szymon Maszke answered Mar 02 '26 07:03


