
Multiple-input multiple-output CNN with custom loss function

I have a set of m x n 2D input arrays, namely A, B, and C, and I have to predict two m x n 2D output arrays, namely d and e, for which I do have the expected (target) values. You can think of the inputs/outputs as greyscale images if you like.

Because the spatial information is relevant (these are actually 2D physical domains), I want to use a Convolutional Neural Network to predict d and e. My design (not tested yet) looks as follows:

[architecture diagram: multi-column encoder-decoder CNN with merged output]

Because I have multiple inputs, I guess I should use multiple columns (or branches) to extract different features from each of the inputs (they look fairly different). Each of these columns follows an encoding-decoding architecture like the one used in segmentation (see SegNet): a Conv2D block consists of a convolution + batch normalisation + ReLU layer, and a Deconv2D block consists of a deconvolution (transposed convolution) + batch normalisation + ReLU.
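
To make this concrete, here is a rough sketch of what I have in mind for a single column (tf.keras; the filter counts, depth, and kernel sizes are just placeholders, not fixed design choices):

```python
# Sketch of one encoder-decoder column (filter counts and depth are placeholders).
from tensorflow.keras import layers

def conv_block(x, filters):
    """Conv2D block: convolution + batch normalisation + ReLU."""
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)

def deconv_block(x, filters):
    """Deconv2D block: transposed convolution + batch normalisation + ReLU."""
    x = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)

def column(inp):
    """One encoder-decoder column for a single m x n input (m, n divisible by 4 here)."""
    x = conv_block(inp, 32)
    x = layers.MaxPooling2D()(x)   # encode: m x n -> m/2 x n/2
    x = conv_block(x, 64)
    x = layers.MaxPooling2D()(x)   #         -> m/4 x n/4
    x = deconv_block(x, 64)        # decode: -> m/2 x n/2
    return deconv_block(x, 32)     #         -> m x n
```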

Then, I can merge the outputs of the columns by, for example, concatenating, averaging, or taking the element-wise maximum. To recover the original m x n shape for each of the outputs, I have seen that this can be done with a 1 x 1 kernel convolution.

I want to predict the two outputs from that single merged layer. Is that okay from the network-structure point of view? Finally, my loss function compares the predicted outputs to their targets and also includes another relation between d and e that I want to impose.
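
Putting it together, this is roughly how I imagine wiring the whole model (again a sketch with placeholder sizes; extra_relation below is just a stand-in for the additional relation, not its actual form):

```python
# Sketch of the full multi-input / multi-output model with a custom loss
# (tf.keras; shapes, filter counts and the extra_relation term are placeholders).
import tensorflow as tf
from tensorflow.keras import layers, Model

m, n = 64, 64  # placeholder grid size

def column(inp):
    # simplified column: two conv + batch norm + ReLU blocks
    x = layers.Conv2D(32, 3, padding="same")(inp)
    x = layers.ReLU()(layers.BatchNormalization()(x))
    x = layers.Conv2D(32, 3, padding="same")(x)
    return layers.ReLU()(layers.BatchNormalization()(x))

inputs = [layers.Input(shape=(m, n, 1), name=name) for name in ("A", "B", "C")]
merged = layers.Concatenate()([column(i) for i in inputs])  # or Average / Maximum

# 1 x 1 convolutions map the merged features to one m x n map per output
d = layers.Conv2D(1, 1, name="d")(merged)
e = layers.Conv2D(1, 1, name="e")(merged)

model = Model(inputs=inputs, outputs=[d, e])

def extra_relation(d_pred, e_pred):
    # stand-in for the additional relation I want to impose between d and e
    return tf.reduce_mean(tf.square(d_pred - e_pred))

# per-output data terms plus a weighted relation term added to the total loss
model.add_loss(0.1 * extra_relation(d, e))
model.compile(optimizer="adam", loss={"d": "mse", "e": "mse"})
```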

I would like some expert opinion on this, since this is my first CNN design and I am not sure if it makes sense as it is now and/or if there are better approaches (or network architectures) to this problem.

I posted this originally on datascience but did not get much feedback. I am now posting it here since there is a bigger community on these topics, and I would also be very grateful to receive implementation tips besides architectural ones. Thanks.

asked Dec 19 '18 by b-fg



2 Answers

I think your design makes sense in general:

  • since A, B, and C are fairly different, you give each input its own transform sub-network and then fuse them together, which gives your intermediate representation.

  • from the intermediate representation, you apply additional CNN branches to decode D and E, respectively.

Several things:

  1. A, B, and C looking different does not necessarily mean you can't stack them together as a 3-channel input. The decision should be based on whether the values in A, B, and C have different meanings or not. For example, suppose A is a grayscale image, B is a depth map, and C is also a grayscale image captured by a different camera. Then A and B are better processed in your suggested way, but A and C can be concatenated into one single input before feeding it to your network.

  2. D and E are two outputs of the network and will be trained in a multi-task manner. Of course, they should share some latent feature, and one should split at that feature to apply a downstream branch with non-shared weights for each output. However, where to split is usually tricky (a rough sketch of both points follows below).
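
A rough tf.keras sketch of both points, assuming for illustration that A and C can be stacked as channels while B is kept separate, with a shared trunk that splits into non-shared branches for D and E (all layer sizes are arbitrary):

```python
# Sketch: A and C stacked as channels, B processed separately, then a shared
# trunk that splits into non-shared branches for D and E (arbitrary sizes).
from tensorflow.keras import layers, Model

m, n = 64, 64  # placeholder spatial size

A_and_C = layers.Input(shape=(m, n, 2), name="A_and_C")  # A and C as two channels
B = layers.Input(shape=(m, n, 1), name="B")              # kept as its own input

def branch(x, filters):
    # each call creates new layers, i.e. non-shared weights
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return layers.Conv2D(filters, 3, padding="same", activation="relu")(x)

fused = layers.Concatenate()([branch(A_and_C, 32), branch(B, 32)])  # shared latent feature
fused = layers.Conv2D(64, 3, padding="same", activation="relu")(fused)

# split here: a separate downstream branch per task
D = layers.Conv2D(1, 1, name="D")(branch(fused, 32))
E = layers.Conv2D(1, 1, name="E")(branch(fused, 32))

model = Model(inputs=[A_and_C, B], outputs=[D, E])
```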

answered Oct 11 '22 by pitfall


It is really a broad question, asking for answers relying mostly on opinions. Here are my two cents though, which you might find interesting as they do not go along with the previous answers here and on datascience.

First, I wouldn't go with separate columns for each input. AFAIK, when different inputs are processed by different columns, it is almost always the case that the network is some sort of Siamese network and the columns share the same weights; or at least the columns all need to produce a similar code. That is not your case here, so I would simply not bother.

Second, you are blessed with a problem that has a dense output and no need to learn a code. This should direct you straight to U-nets, which outperform any bottleneck-designed network without much effort. U-nets were introduced for dense segmentation, but they really shine on any dense-output problem.

In short, just stack your inputs together and use a U-net.
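
For illustration, a minimal U-net-style sketch in tf.keras (arbitrary depth and filter counts; the three inputs stacked as channels, two dense m x n outputs):

```python
# Minimal U-net-style sketch: inputs A, B, C stacked as channels, skip
# connections between encoder and decoder, two dense m x n outputs.
from tensorflow.keras import layers, Model

m, n = 64, 64  # placeholder; m and n must be divisible by 4 at this depth

def conv(x, filters):
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return layers.Conv2D(filters, 3, padding="same", activation="relu")(x)

inp = layers.Input(shape=(m, n, 3), name="ABC")   # A, B, C as three channels

c1 = conv(inp, 32)
p1 = layers.MaxPooling2D()(c1)
c2 = conv(p1, 64)
p2 = layers.MaxPooling2D()(c2)

bottom = conv(p2, 128)

u2 = layers.Conv2DTranspose(64, 2, strides=2, padding="same")(bottom)
c3 = conv(layers.Concatenate()([u2, c2]), 64)     # skip connection
u1 = layers.Conv2DTranspose(32, 2, strides=2, padding="same")(c3)
c4 = conv(layers.Concatenate()([u1, c1]), 32)     # skip connection

d = layers.Conv2D(1, 1, name="d")(c4)
e = layers.Conv2D(1, 1, name="e")(c4)

model = Model(inputs=inp, outputs=[d, e])
```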

answered Oct 11 '22 by P-Gn