Siamese Neural Network in Pytorch

Tags: python, pytorch

How can I implement a siamese neural network in PyTorch?

What is a siamese neural network? A siamese neural network consists of two identical neural networks, each taking one input. "Identical" means that the two networks have exactly the same architecture and share the same weights.
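A minimal sketch of that definition, assuming a hypothetical `SiameseNetwork` wrapper that is not part of the original question: both "twins" are the same encoder module, so identical architecture and shared weights hold by construction. The layer sizes mirror the answer's snippet and are otherwise arbitrary.

```python
import torch
import torch.nn as nn


class SiameseNetwork(nn.Module):
    """Hypothetical wrapper: one encoder applied to both inputs."""

    def __init__(self, in_features: int = 10, embedding_dim: int = 2):
        super().__init__()
        # A single encoder; both branches use this one module, so they
        # have the same architecture and the same weights.
        self.encoder = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, embedding_dim),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        # Run each input through the shared encoder.
        return self.encoder(x1), self.encoder(x2)


net = SiameseNetwork()
x = torch.randn(4, 10)
out1, out2 = net(x, x)
# Identical inputs give identical outputs because the weights are shared.
print(torch.equal(out1, out2))  # True
```

Counting parameters confirms there is only one set of weights: `sum(p.numel() for p in net.parameters())` equals the size of a single encoder, not two.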


BiBi asked Mar 04 '23 21:03


1 Answer

Implementing a siamese neural network in PyTorch is as simple as calling the same network module twice, once on each input.

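import torch
import torch.nn as nn

# One network; calling it twice reuses (shares) the same weights.
mynet = nn.Sequential(
    nn.Linear(10, 512),
    nn.ReLU(),
    nn.Linear(512, 2))

input1 = torch.randn(8, 10)  # example batch for the first branch
input2 = torch.randn(8, 10)  # example batch for the second branch
output1 = mynet(input1)
output2 = mynet(input2)

# Placeholder loss for illustration: any loss that depends on both outputs works.
loss = (output1 - output2).pow(2).sum(dim=1).mean()
loss.backward()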

When loss.backward() is invoked, PyTorch's autograd automatically sums the gradients coming from the two invocations of mynet, so each shared parameter's .grad holds the total gradient from both branches.
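The gradient-summing claim can be checked directly. This sketch uses a tiny `nn.Linear(3, 1)` as an illustrative stand-in for `mynet` and compares the gradient of one combined loss against the sum of the gradients from each branch computed separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 1)  # a tiny shared network, for illustration only
x1, x2 = torch.randn(5, 3), torch.randn(5, 3)

# Gradient from the first branch on its own.
net.zero_grad()
net(x1).sum().backward()
g1 = net.weight.grad.clone()

# Gradient from the second branch on its own.
net.zero_grad()
net(x2).sum().backward()
g2 = net.weight.grad.clone()

# Gradient when both calls feed a single loss, as in a siamese network.
net.zero_grad()
(net(x1).sum() + net(x2).sum()).backward()
g_both = net.weight.grad.clone()

print(torch.allclose(g_both, g1 + g2))  # True
```

The same accumulation is why `zero_grad()` is needed between independent backward passes: without it, `.grad` keeps adding up across calls.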

You can find a full-fledged example here.
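For a self-contained training sketch, the code below assumes a contrastive loss (one common choice for siamese networks, not necessarily the one used in the linked example), a margin of 1.0, random toy pairs, and plain SGD. The `contrastive_loss` helper is hypothetical; `F.pairwise_distance` is the standard PyTorch function.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_loss(out1, out2, same, margin=1.0):
    """Hypothetical helper; `same` is 1.0 for similar pairs, 0.0 otherwise."""
    d = F.pairwise_distance(out1, out2)
    # Pull similar pairs together; push dissimilar pairs at least `margin` apart.
    return (same * d.pow(2) + (1 - same) * F.relu(margin - d).pow(2)).mean()


torch.manual_seed(0)
mynet = nn.Sequential(nn.Linear(10, 512), nn.ReLU(), nn.Linear(512, 2))
optimizer = torch.optim.SGD(mynet.parameters(), lr=0.01)

# Random toy pairs and similarity labels, purely for illustration.
input1, input2 = torch.randn(32, 10), torch.randn(32, 10)
same = torch.randint(0, 2, (32,)).float()

for step in range(100):
    optimizer.zero_grad()
    # Both calls go through the same mynet, so its weights are shared.
    loss = contrastive_loss(mynet(input1), mynet(input2), same)
    loss.backward()  # gradients from both calls accumulate in mynet
    optimizer.step()
```

Label conventions differ between sources (some use 0 for similar pairs), so check which convention a given loss expects before reusing it.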

BiBi answered Mar 16 '23 05:03