 

How to construct a network with two inputs in PyTorch

Suppose I want the following general neural network architecture:

Input1 --> CNNLayer 
                    \
                     ---> FCLayer ---> Output
                    /
Input2 --> FCLayer

Input1 is image data and Input2 is non-image data. I have already implemented this architecture in TensorFlow.

All the PyTorch examples I have found pass a single input through each layer in turn. How can I define the forward function so that it processes the two inputs separately and then combines them in a middle layer?

asked Aug 06 '18 04:08 by LeonG


People also ask

How do I combine two layers in PyTorch?

We can join tensors in PyTorch using the torch.cat() and torch.stack() functions. Both join a sequence of tensors, but torch.cat() concatenates them along an existing dimension, while torch.stack() stacks them along a new dimension.
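A minimal illustration of the difference, using two small tensors of the same shape (the values and shapes are arbitrary and chosen only for demonstration):

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# torch.cat joins along an existing dimension: (2, 3) and (2, 3) -> (2, 6) on dim=1
catted = torch.cat((a, b), dim=1)

# torch.stack inserts a new dimension: two (2, 3) tensors -> (2, 2, 3) on dim=0
stacked = torch.stack((a, b), dim=0)

print(catted.shape)   # torch.Size([2, 6])
print(stacked.shape)  # torch.Size([2, 2, 3])
```

For merging two network branches, torch.cat along the feature dimension is usually what you want, since the branches generally produce different numbers of features and torch.stack requires identical shapes.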


1 Answer

By "combine them" I assume you mean to concatenate the outputs of the two branches.
Assuming you concatenate along the second dimension (dim=1, the feature dimension):

import torch
from torch import nn

class TwoInputsNet(nn.Module):
  def __init__(self):
    super(TwoInputsNet, self).__init__()
    self.conv = nn.Conv2d( ... )  # set up your layer here
    self.fc1 = nn.Linear( ... )  # set up first FC layer
    self.fc2 = nn.Linear( ... )  # set up the other FC layer

  def forward(self, input1, input2):
    c = self.conv(input1)
    f = self.fc1(input2)
    # now we can reshape `c` and `f` to 2D and concat them
    combined = torch.cat((c.view(c.size(0), -1),
                          f.view(f.size(0), -1)), dim=1)
    out = self.fc2(combined)
    return out

Note that when you set the number of inputs (in_features) of self.fc2, you need to account for the out_channels of self.conv, the output spatial dimensions of c (its flattened size is out_channels × height × width), and the out_features of self.fc1.
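To make the size bookkeeping concrete, here is a runnable sketch with the `...` placeholders filled in under hypothetical assumptions: Input1 is a batch of 3×32×32 images, Input2 is a batch of 10-dimensional feature vectors, and a max-pooling layer and ReLU activations are added for illustration. None of these sizes come from the question; adjust them to your data.

```python
import torch
from torch import nn

class TwoInputsNet(nn.Module):
    # Hypothetical shapes: input1 is (N, 3, 32, 32), input2 is (N, 10).
    def __init__(self, n_features=10, n_outputs=1):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the 32x32 spatial size;
        # MaxPool2d(2) then halves it to 16x16.
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(n_features, 32)
        # in_features = conv out_channels * H * W + fc1 out_features
        conv_flat = 16 * 16 * 16  # 16 channels x 16 x 16 = 4096
        self.fc2 = nn.Linear(conv_flat + 32, n_outputs)  # 4096 + 32 = 4128

    def forward(self, input1, input2):
        c = self.pool(torch.relu(self.conv(input1)))  # (N, 16, 16, 16)
        f = torch.relu(self.fc1(input2))              # (N, 32)
        # Flatten everything except the batch dimension, then concatenate features
        combined = torch.cat((torch.flatten(c, 1), f), dim=1)  # (N, 4128)
        return self.fc2(combined)                     # (N, n_outputs)

net = TwoInputsNet()
images = torch.randn(4, 3, 32, 32)
features = torch.randn(4, 10)
out = net(images, features)
print(out.shape)  # torch.Size([4, 1])
```

If you would rather not compute the flattened size by hand, recent PyTorch versions provide nn.LazyLinear, which infers in_features from the first batch it sees; the explicit arithmetic above has the advantage that a shape mismatch fails at the layer definition you can read.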

answered Oct 22 '22 18:10 by Shai