Better way to concatenate ConvLSTM2D model and Tabular model

I have built a model that takes 3 images of a time series, along with 5 numerical features, as input and produces the next 3 images of the time series. I accomplished this by:

  1. Building a ConvLSTM2D model to process the images (pretty similar to the example in the Keras documentation here). Input size = (3, 128, 128, 3)
  2. Building a simple model for the tabular data with a few Dense layers. Input size = (1, 5)
  3. Concatenating these two models
  4. Using a Conv3D model that produces the next 3 images

The LSTM model produces an output of size 393,216 (3x128x128x8). To get an input of size 442,368 (3x128x128x9) in the next layer, I had to set the tabular model's output to 49,152. This unnecessary inflation of the tabular model's Dense layer makes the otherwise efficient LSTM model perform awfully.
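
For reference, the 49,152 comes purely from shape bookkeeping (a quick check using the sizes above):

3 * 128 * 128 * 8   # = 393,216  -> flattened ConvLSTM2D output
3 * 128 * 128 * 9   # = 442,368  -> what Reshape((3, 128, 128, 9)) expects
442368 - 393216     # = 49,152   -> units the tabular branch has to supply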

Is there a better way to concatenate the two models? Is there a way I can just have an output of 10 in the tabular model's Dense layer?

The model:

from tensorflow.keras.layers import (Input, ConvLSTM2D, BatchNormalization, Flatten,
                                     Dense, Concatenate, Reshape, Conv3D)
from tensorflow.keras.models import Model

x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
# x = MaxPooling3D()(x)

x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(49152, activation="relu")(x_tab)
x_tab = Flatten()(x_tab)

concat = Concatenate()([x, x_tab])

output = Reshape((3, 128, 128, 9))(concat)
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
model = Model([x_input, x_tab_input], output)
model.compile(loss='mae', optimizer='rmsprop')

Model Summary:

Model: "functional_3"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
======================================================================================================================================================
input_4 (InputLayer)                             [(None, None, 128, 128, 3)]      0                                                                   
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_9 (ConvLSTM2D)                      (None, None, 128, 128, 32)       40448             input_4[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization)       (None, None, 128, 128, 32)       128               conv_lst_m2d_9[0][0]                              
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_10 (ConvLSTM2D)                     (None, None, 128, 128, 16)       27712             batch_normalization_9[0][0]                       
______________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization)      (None, None, 128, 128, 16)       64                conv_lst_m2d_10[0][0]                             
______________________________________________________________________________________________________________________________________________________
input_5 (InputLayer)                             [(None, 5)]                      0                                                                   
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_11 (ConvLSTM2D)                     (None, None, 128, 128, 8)        6944              batch_normalization_10[0][0]                      
______________________________________________________________________________________________________________________________________________________
dense (Dense)                                    (None, 100)                      600               input_5[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization)      (None, None, 128, 128, 8)        32                conv_lst_m2d_11[0][0]                             
______________________________________________________________________________________________________________________________________________________
dense_1 (Dense)                                  (None, 49152)                    4964352           dense[0][0]                                       
______________________________________________________________________________________________________________________________________________________
flatten_3 (Flatten)                              (None, None)                     0                 batch_normalization_11[0][0]                      
______________________________________________________________________________________________________________________________________________________
flatten_4 (Flatten)                              (None, 49152)                    0                 dense_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
concatenate (Concatenate)                        (None, None)                     0                 flatten_3[0][0]                                   
                                                                                                    flatten_4[0][0]                                   
______________________________________________________________________________________________________________________________________________________
reshape_2 (Reshape)                              (None, 3, 128, 128, 9)           0                 concatenate[0][0]                                 
______________________________________________________________________________________________________________________________________________________
conv3d_2 (Conv3D)                                (None, 3, 128, 128, 3)           732               reshape_2[0][0]                                   
======================================================================================================================================================
Total params: 5,041,012
Trainable params: 5,040,900
Non-trainable params: 112
______________________________________________________________________________________________________________________________________________________
asked Jan 30 '21 by vettipayyan
1 Answer

I agree with you that the huge Dense layer (which has millions of parameters) might hinder the performance of the model. Instead of inflating the tabular data with a Dense layer, you could take one of the following two approaches.


Option 1: Tile the x_tab tensor so that it matches your desired shape. This can be achieved with the following steps:

First, there is no need to flatten the ConvLSTM2D's encoded tensor:

x_input = Input(shape=(3, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)  # Shape=(None, 3, 128, 128, 8)
# Commented: x = Flatten()(x)

Second, you can process your tabular data with one or several Dense layers. For example:

dim = 10
x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
# x_tab = Flatten()(x_tab)  # Note: Flattening a 2D tensor leaves the tensor unchanged

Third, we wrap the TensorFlow operation tf.tile in a Lambda layer, effectively creating copies of the x_tab tensor so that it matches the desired shape:

import tensorflow as tf
from tensorflow.keras.layers import Lambda

def repeat_tabular(x_tab):
    h = x_tab[:, None, None, None, :]    # Shape=(bs, 1, 1, 1, dim)
    h = tf.tile(h, [1, 3, 128, 128, 1])  # Shape=(bs, 3, 128, 128, dim)
    return h

x_tab = Lambda(repeat_tabular)(x_tab)

Finally, we concatenate x and the tiled x_tab tensor along the last axis, which here corresponds to the channels dimension:

concat = Concatenate(axis=-1)([x, x_tab])  # Shape=(None, 3, 128, 128, 8+dim)
output = concat
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
# ...
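
Wrapping this up into a trainable model then looks the same as in the question (a minimal sketch, assuming no further layers are added after the Conv3D above):

model = Model([x_input, x_tab_input], output)
model.compile(loss='mae', optimizer='rmsprop')
model.summary()  # last layer should report an output shape of (None, 3, 128, 128, 3)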

Note that this solution might be a bit naive in the sense that the model is not encoding the input sequence of images into a low-dimensional representation, limiting the receptive field of the network and potentially resulting in degraded performance.


Option 2: Similar to autoencoders and U-Net, it might be desirable to encode your sequence of images into a low-dimensional representation in order to discard the unwanted variation (e.g. noise) while preserving the meaningful signal (e.g. required to infer the next 3 images of the sequence). This can be achieved as follows:

First, encode the input sequence of images into a low-dimensional representation (a 2D tensor of shape (batch_size, features)). For example, something along the lines of:

x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=False)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)

Note that the last ConvLSTM2D does not return sequences. You might want to explore different encoders to arrive at this point (e.g. you could also use pooling layers here; see the sketch below).
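
For instance, a pooling-based variant (a hypothetical sketch, not part of the original answer; the pool size and bottleneck width are arbitrary choices) could downsample spatially before flattening, which also keeps the Dense bottleneck small:

from tensorflow.keras.layers import MaxPooling2D

# Continuing from the last BatchNormalization above, whose output has
# shape (None, 128, 128, 8) because return_sequences=False.
x = MaxPooling2D(pool_size=4)(x)     # (None, 32, 32, 8)
x = Flatten()(x)                     # (None, 8192) instead of (None, 131072)
x = Dense(64, activation='relu')(x)  # 64-dim bottleneck with far fewer weights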

Second, process your tabular data with Dense layers. For example:

dim = 10
x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)

Third, concatenate the data from the previous two streams:

concat = Concatenate(axis=-1)([x, x_tab])

Fourth, use a Dense + Reshape layer to project the concatenated vectors into a sequence of low-resolution images:

h = Dense(3 * 32 * 32 * 3)(concat)
output = Reshape((3, 32, 32, 3))(h)

The shape of output allows the images to be up-sampled to (128, 128, 3), but it is otherwise arbitrary (you might want to experiment here as well).

Finally, apply one or several Conv3DTranspose layers to get to the desired output (e.g. 3 images of shape (128, 128, 3)).

output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),
                                         strides=(1, 2, 2), padding='same',
                                         activation='relu')(output)
output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),
                                         strides=(1, 2, 2), padding='same',
                                         activation='relu')(output)  # Shape=(None, 3, 128, 128, 3)

The rationale behind transposed convolution layers is discussed here. Essentially, the Conv3DTranspose layer goes in the opposite direction of normal convolutions - it allows upsampling your low-resolution images into high-resolution images.
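
As a quick sanity check of that upsampling behaviour (a standalone snippet, not part of the model above): with padding='same' and strides=(1, 2, 2), each Conv3DTranspose doubles the two spatial dimensions and leaves the time dimension untouched.

import tensorflow as tf

x = tf.zeros((1, 3, 32, 32, 3))  # one low-resolution sequence of 3 frames
up = tf.keras.layers.Conv3DTranspose(filters=8, kernel_size=3,
                                     strides=(1, 2, 2), padding='same')(x)
print(up.shape)  # (1, 3, 64, 64, 8)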

answered Oct 20 '22 by rvinas