Understanding Bilinear Layers

I can't wrap my head around how a bilinear layer in PyTorch (nn.Bilinear) computes its output.

Here is a small example where I tried to figure out how it works:

In:

import torch
import torch.nn as nn
B = nn.Bilinear(2, 2, 1)
print(B.weight)

Out:

Parameter containing:
tensor([[[-0.4394, -0.4920],
         [ 0.6137,  0.4174]]], requires_grad=True)

I pass a vector of ones and a vector of zeros through the layer, in both orders:

In:

print(B(torch.ones(2), torch.zeros(2)))
print(B(torch.zeros(2), torch.ones(2)))

Out:

tensor([0.2175], grad_fn=<ThAddBackward>)
tensor([0.2175], grad_fn=<ThAddBackward>)

I tried summing the weights in various ways, but I can't reproduce this result.

Thanks in advance!

Asked by MBT on Aug 10 '18 08:08


1 Answer

The operation performed by nn.Bilinear is B(x1, x2) = x1^T * A * x2 + b (cf. the documentation), with:

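  • A stored in nn.Bilinear.weight, with shape (out_features, in1_features, in2_features), i.e. one matrix per output feature
  • b stored in nn.Bilinear.bias, with shape (out_features,)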
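
To make the formula concrete, here is a minimal pure-Python sketch of y_k = x1^T * A_k * x2 + b_k, assuming the weight is laid out like nn.Bilinear.weight (out_features matrices of size in1_features x in2_features). The function name `bilinear` and the toy weights are hypothetical, chosen only so the result can be checked by hand; this is not PyTorch's implementation.

```python
def bilinear(x1, x2, weight, bias=None):
    """Hypothetical pure-Python version of y_k = x1^T A_k x2 + b_k.

    weight is nested as [out_features][in1_features][in2_features],
    mirroring the layout of nn.Bilinear.weight.
    """
    out = []
    for k, A_k in enumerate(weight):
        # Sum over every (i, j) pair: x1[i] * A_k[i][j] * x2[j]
        y_k = sum(
            x1[i] * A_k[i][j] * x2[j]
            for i in range(len(x1))
            for j in range(len(x2))
        )
        if bias is not None:
            y_k += bias[k]
        out.append(y_k)
    return out


# Toy example: two output features, each with its own 2x2 matrix A_k.
weight = [
    [[1.0, 0.0], [0.0, 1.0]],  # A_0 = identity, so y_0 = x1 . x2
    [[0.0, 1.0], [0.0, 0.0]],  # A_1 picks out x1[0] * x2[1]
]
bias = [0.5, 0.0]
print(bilinear([1.0, 2.0], [3.0, 4.0], weight, bias))  # [11.5, 4.0]
```

Each output feature is a weighted sum over all pairs x1[i] * x2[j], so there is no single "sum of the weights" unless both inputs are all ones.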

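Once you take the (optional) bias into account, you obtain the expected results. In particular, if either input is the zero vector, the term x1^T * A * x2 vanishes and the output is just b, which is why both calls in the question print the same value.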
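
Applied to the numbers in the question, a small pure-Python sketch reproduces both outputs. The bias is not printed in the question, so its value 0.2175 is an assumption inferred from the observed output (with a zero input, the output can only be b). With both inputs all ones, every weight is counted exactly once, so the formula predicts sum(A) + b; that last value is a prediction, not an output shown in the question.

```python
# Weight printed in the question (out_features=1, so a single 2x2 matrix).
A = [[-0.4394, -0.4920],
     [ 0.6137,  0.4174]]
# Assumption: bias inferred from the question's output, because
# x1^T A x2 = 0 whenever x1 or x2 is the zero vector.
b = 0.2175


def bilinear_single(x1, x2):
    # x1^T A x2 + b written out as a double sum over (i, j)
    return sum(x1[i] * A[i][j] * x2[j] for i in range(2) for j in range(2)) + b


ones, zeros = [1.0, 1.0], [0.0, 0.0]
print(round(bilinear_single(ones, zeros), 4))  # 0.2175
print(round(bilinear_single(zeros, ones), 4))  # 0.2175
# Both inputs all ones: -0.4394 - 0.4920 + 0.6137 + 0.4174 + 0.2175
print(round(bilinear_single(ones, ones), 4))   # 0.3172
```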


import torch
import torch.nn as nn

def manual_bilinear(x1, x2, A, b):
    # x1: (1, in1), A: (in1, in2), x2: (in2, 1) -> (1, 1), then add the bias
    return torch.mm(x1, torch.mm(A, x2)) + b

x_ones = torch.ones(2)
x_zeros = torch.zeros(2)

# ---------------------------
# With Bias:

B = nn.Bilinear(2, 2, 1)  # randomly initialized, so your values will differ
A = B.weight
print(B.bias)
# > tensor([-0.6748], requires_grad=True)
b = B.bias

print(B(x_ones, x_zeros))
# > tensor([-0.6748], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_zeros.view(2, 1), A.squeeze(), b))
# > tensor([[-0.6748]], grad_fn=<ThAddBackward>)

print(B(x_ones, x_ones))
# > tensor([-1.7684], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_ones.view(2, 1), A.squeeze(), b))
# > tensor([[-1.7684]], grad_fn=<ThAddBackward>)

# ---------------------------
# Without Bias:

B = nn.Bilinear(2, 2, 1, bias=False)
A = B.weight
print(B.bias)
# None
b = torch.zeros(1)

print(B(x_ones, x_zeros))
# > tensor([0.], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_zeros.view(2, 1), A.squeeze(), b))
# > tensor([[0.]], grad_fn=<ThAddBackward>)

print(B(x_ones, x_ones))
# > tensor([-0.7897], grad_fn=<ThAddBackward>)
print(manual_bilinear(x_ones.view(1, 2), x_ones.view(2, 1), A.squeeze(), b))
# > tensor([[-0.7897]], grad_fn=<ThAddBackward>)
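
The manual check above only covers one output feature (hence `A.squeeze()`) and unbatched vectors. As a sketch under arbitrary assumptions (sizes 3, 4, 5, batch of 8, seed 0), the same identity can be checked for batched inputs and several output features with `torch.einsum`. This is an equivalent formulation of the documented formula, not necessarily how PyTorch computes it internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Arbitrary sizes: in1_features=3, in2_features=4, out_features=5
layer = nn.Bilinear(3, 4, 5)
print(layer.weight.shape)  # torch.Size([5, 3, 4])

x1 = torch.randn(8, 3)  # batch of 8 first inputs
x2 = torch.randn(8, 4)  # batch of 8 second inputs

# y[n, k] = sum_ij x1[n, i] * W[k, i, j] * x2[n, j] + bias[k]
manual = torch.einsum("ni,kij,nj->nk", x1, layer.weight, x2) + layer.bias

print(torch.allclose(layer(x1, x2), manual, atol=1e-5))  # True
```

The einsum string spells out the formula index by index: n is the batch, k the output feature, and i, j are summed over.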
Answered by benjaminplanche on Oct 02 '22 16:10