I have built a model that takes 3 images of a time series along with 5 numerical features as input and produces the next three images of the time series. I accomplished this by concatenating the output of a ConvLSTM2D model (for the images) with the output of a Dense model (for the tabular data).
The ConvLSTM model produces an output of size 393,216 (3x128x128x8). To get an input of size 442,368 (3x128x128x9) for the next layer, I had to set the output of the tabular model to 49,152. This unnecessary inflation of the tabular model's Dense layer makes the otherwise efficient LSTM model perform awfully.
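The arithmetic behind these sizes can be checked with a few lines of plain Python (the variable names are only illustrative):

```python
# Sizes of the flattened tensors in the model below (batch dimension excluded).
frames, height, width = 3, 128, 128
lstm_channels = 8        # filters of the last ConvLSTM2D
target_channels = 9      # channels expected by Reshape((3, 128, 128, 9))

lstm_size = frames * height * width * lstm_channels      # flattened ConvLSTM2D output
target_size = frames * height * width * target_channels  # size the Reshape layer needs
tab_size = target_size - lstm_size                       # what the tabular branch must supply

print(lstm_size, target_size, tab_size)
```

In other words, adding one extra channel at every one of the 3 x 128 x 128 positions is what forces the tabular Dense layer up to 49,152 units.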
Is there a better way to concatenate the two models? Is there a way I can just have an output of 10 in the tabular model's Dense layer?
The model:
from tensorflow.keras.layers import (Input, ConvLSTM2D, BatchNormalization, Flatten,
                                     Dense, Concatenate, Reshape, Conv3D)
from tensorflow.keras.models import Model
# Image branch: a sequence of 128x128 RGB frames
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
# x = MaxPooling3D()(x)
# Tabular branch: 5 numerical features
x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(49152, activation="relu")(x_tab)
x_tab = Flatten()(x_tab)
concat = Concatenate()([x, x_tab])
output = Reshape((3, 128, 128, 9))(concat)
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
model = Model([x_input, x_tab_input], output)
model.compile(loss='mae', optimizer='rmsprop')
Model Summary:
Model: "functional_3"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
======================================================================================================================================================
input_4 (InputLayer)                             [(None, None, 128, 128, 3)]      0                                                                   
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_9 (ConvLSTM2D)                      (None, None, 128, 128, 32)       40448             input_4[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization)       (None, None, 128, 128, 32)       128               conv_lst_m2d_9[0][0]                              
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_10 (ConvLSTM2D)                     (None, None, 128, 128, 16)       27712             batch_normalization_9[0][0]                       
______________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization)      (None, None, 128, 128, 16)       64                conv_lst_m2d_10[0][0]                             
______________________________________________________________________________________________________________________________________________________
input_5 (InputLayer)                             [(None, 5)]                      0                                                                   
______________________________________________________________________________________________________________________________________________________
conv_lst_m2d_11 (ConvLSTM2D)                     (None, None, 128, 128, 8)        6944              batch_normalization_10[0][0]                      
______________________________________________________________________________________________________________________________________________________
dense (Dense)                                    (None, 100)                      600               input_5[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization)      (None, None, 128, 128, 8)        32                conv_lst_m2d_11[0][0]                             
______________________________________________________________________________________________________________________________________________________
dense_1 (Dense)                                  (None, 49152)                    4964352           dense[0][0]                                       
______________________________________________________________________________________________________________________________________________________
flatten_3 (Flatten)                              (None, None)                     0                 batch_normalization_11[0][0]                      
______________________________________________________________________________________________________________________________________________________
flatten_4 (Flatten)                              (None, 49152)                    0                 dense_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
concatenate (Concatenate)                        (None, None)                     0                 flatten_3[0][0]                                   
                                                                                                    flatten_4[0][0]                                   
______________________________________________________________________________________________________________________________________________________
reshape_2 (Reshape)                              (None, 3, 128, 128, 9)           0                 concatenate[0][0]                                 
______________________________________________________________________________________________________________________________________________________
conv3d_2 (Conv3D)                                (None, 3, 128, 128, 3)           732               reshape_2[0][0]                                   
======================================================================================================================================================
Total params: 5,041,012
Trainable params: 5,040,900
Non-trainable params: 112
______________________________________________________________________________________________________________________________________________________
I agree with you that the huge Dense layer (dense_1 alone has 4,964,352 parameters) might hinder the performance of the model. Instead of inflating the tabular data with a Dense layer, you could choose one of the following two approaches.
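For scale: a Dense layer with a bias has n_in x n_out weights plus n_out biases. This short check reproduces the dense_1 figure from the summary above and compares it with a 10-unit layer (the helper name is hypothetical):

```python
def dense_params(n_in, n_out):
    """Parameters of a Dense layer with bias: one weight per input/unit pair plus one bias per unit."""
    return n_in * n_out + n_out

inflated = dense_params(100, 49152)  # dense_1 in the question's summary
compact = dense_params(100, 10)      # a Dense(10) tabular head instead

print(inflated, compact)
```

The 10-unit head needs roughly 5,000 times fewer parameters, which is why the approaches below keep the tabular output small and expand it without extra weights.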
Option 1: Tile the x_tab tensor so that it matches your desired shape. This can be achieved with the following steps:
First, there is no need to flatten the ConvLSTM2D's encoded tensor:
import tensorflow as tf
from tensorflow.keras.layers import (Input, ConvLSTM2D, BatchNormalization, Dense,
                                     Lambda, Concatenate, Conv3D)
x_input = Input(shape=(3, 128, 128, 3))  # fixed number of frames (3), needed for the tiling below
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)  # Shape=(None, 3, 128, 128, 8)
# Commented: x = Flatten()(x)  (keep the 5D layout instead)
Second, you can process your tabular data with one or several Dense layers. For example:
dim = 10
x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
# x_tab = Flatten()(x_tab)  # Note: Flattening a 2D tensor leaves the tensor unchanged
Third, we wrap the TensorFlow operation tf.tile in a Lambda layer, which creates copies of the tensor x_tab so that it matches the desired shape:
def repeat_tabular(x_tab):
    h = x_tab[:, None, None, None, :]  # Shape=(bs, 1, 1, 1, dim)
    h = tf.tile(h, [1, 3, 128, 128, 1])  # Shape=(bs, 3, 128, 128, dim)
    return h
x_tab = Lambda(repeat_tabular)(x_tab)
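Because the tiling only copies the same dim-vector to every frame and pixel, its effect can be illustrated with NumPy, where np.tile plays the role of tf.tile. This is a sketch with random stand-in data, not the Keras layer itself; it also shows that broadcasting (np.broadcast_to here, tf.broadcast_to in TensorFlow) yields the same values without materialising copies:

```python
import numpy as np

batch, frames, height, width, dim = 2, 3, 128, 128, 10
x_tab = np.random.rand(batch, dim).astype("float32")  # stand-in for the Dense(dim) output

# Same indexing + tiling as repeat_tabular, but with NumPy instead of tf.tile.
h = x_tab[:, None, None, None, :]                  # (batch, 1, 1, 1, dim)
tiled = np.tile(h, (1, frames, height, width, 1))  # (batch, 3, 128, 128, dim)

# Broadcasting produces the same values as a read-only view.
broadcast = np.broadcast_to(h, (batch, frames, height, width, dim))

print(tiled.shape)
```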
Finally, we concatenate x and the tiled x_tab tensor along the last axis, which is the channels axis in Keras' default channels_last data format:
concat = Concatenate(axis=-1)([x, x_tab])  # Shape=(None, 3, 128, 128, 8 + dim)
output = concat
output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
# ...
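The concatenation appends the same dim tabular values as extra channels at every position. A NumPy sketch with random stand-in arrays (shapes assumed from the code above) shows the result:

```python
import numpy as np

batch, frames, height, width, dim = 2, 3, 128, 128, 10
x = np.random.rand(batch, frames, height, width, 8).astype("float32")  # ConvLSTM2D features
x_tab = np.random.rand(batch, dim).astype("float32")                    # tabular features
x_tab_tiled = np.broadcast_to(x_tab[:, None, None, None, :],
                              (batch, frames, height, width, dim))

# Concatenate on the last (channels) axis, as Concatenate(axis=-1) does.
concat = np.concatenate([x, x_tab_tiled], axis=-1)
print(concat.shape)
```

Every position in each of the 3 frames now carries 8 image channels followed by the same 10 tabular channels, so the Conv3D layer can combine both sources locally.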
Note that this solution might be a bit naive in the sense that the model is not encoding the input sequence of images into a low-dimensional representation, limiting the receptive field of the network and potentially resulting in degraded performance.
Option 2: Similar to autoencoders and U-Net, it might be desirable to encode your sequence of images into a low-dimensional representation in order to discard the unwanted variation (e.g. noise) while preserving the meaningful signal (e.g. the information required to infer the next 3 images of the sequence). This can be achieved as follows:
First, encode the input sequence of images into a low-dimensional 2D tensor of shape (batch, features). For example, something along the lines of:
import tensorflow as tf
from tensorflow.keras.layers import (Input, ConvLSTM2D, BatchNormalization, Flatten,
                                     Dense, Concatenate, Reshape)
x_input = Input(shape=(None, 128, 128, 3))
x = ConvLSTM2D(32, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x_input)
x = BatchNormalization()(x)
x = ConvLSTM2D(16, 3, strides=1, padding='same', dilation_rate=2, return_sequences=True)(x)
x = BatchNormalization()(x)
x = ConvLSTM2D(8, 3, strides=1, padding='same', dilation_rate=2, return_sequences=False)(x)
x = BatchNormalization()(x)  # Shape=(None, 128, 128, 8)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
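Note that the last ConvLSTM2D does not return sequences, so its output is a single (128, 128, 8) feature map per sample. Flattening it gives 131,072 values, so Dense(64) on top of it has 8,388,672 parameters, even more than the layer you are trying to avoid. You might therefore want to explore different encoders to arrive at this point (e.g. pooling layers to shrink the spatial size before flattening).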
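The effect of pooling on this Dense layer can be estimated with the same parameter formula. The 4x4 pooling here (for example MaxPooling2D(pool_size=4) after the last BatchNormalization) is a hypothetical choice, not part of the code above:

```python
def dense_params(n_in, n_out):
    """Parameters of a Dense layer with bias."""
    return n_in * n_out + n_out

flat_full = 128 * 128 * 8   # Flatten() of the (128, 128, 8) ConvLSTM2D output
flat_pooled = 32 * 32 * 8   # after a hypothetical 4x4 pooling: (32, 32, 8)

full = dense_params(flat_full, 64)
pooled = dense_params(flat_pooled, 64)
print(full, pooled)
```

Pooling by 4 in each spatial direction cuts the Dense(64) layer by a factor of 16, at the cost of discarding some spatial detail before the bottleneck.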
Second, process your tabular data with Dense layers. For example:
dim = 10
x_tab_input = Input(shape=(5,))
x_tab = Dense(100, activation="relu")(x_tab_input)
x_tab = Dense(dim, activation="relu")(x_tab)
Third, concatenate the data from the previous two streams:
concat = Concatenate(axis=-1)([x, x_tab])
Fourth, use a Dense + Reshape layer to project the concatenated vectors into a sequence of low-resolution images:
h = Dense(3 * 32 * 32 * 3)(concat)
output = Reshape((3, 32, 32, 3))(h)
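What the Dense + Reshape pair computes can be sketched in NumPy, assuming the 64-unit image encoding and dim = 10 from the snippets above; random weights stand in for trained ones and no activation is applied, matching the Dense layer above:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, img_features, dim = 2, 64, 10
concat = rng.random((batch, img_features + dim), dtype=np.float32)  # Concatenate output

# A Dense(3 * 32 * 32 * 3) layer is a matrix multiply plus a bias.
units = 3 * 32 * 32 * 3
W = rng.standard_normal((img_features + dim, units)).astype(np.float32)
b = np.zeros(units, dtype=np.float32)
h = concat @ W + b                       # (batch, 9216)

output = h.reshape(batch, 3, 32, 32, 3)  # Reshape((3, 32, 32, 3)) keeps the batch axis
print(output.shape, W.size + b.size)
```

This projection has 691,200 parameters, still far fewer than the 8 + 1 channel layout of the question required through the tabular branch combined with the large flattened encoder.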
This shape allows the images to be upsampled to (128, 128, 3) by two stride-2 transposed convolutions, but it is otherwise arbitrary (you might want to experiment here).
Finally, apply one or several Conv3DTranspose layers to get to the desired output (e.g. 3 images of shape (128, 128, 3)).
output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),
                                         strides=(1, 2, 2), padding='same',
                                         activation='relu')(output)
output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),
                                         strides=(1, 2, 2), padding='same',
                                         activation='relu')(output)  # Shape=(None, 3, 128, 128, 3)
The rationale behind transposed convolution layers is discussed here. Essentially, the Conv3DTranspose layer goes in the opposite direction of a normal convolution: it upsamples your low-resolution images into high-resolution images.
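Two stride-(1, 2, 2) layers are enough because, with padding='same' and no output_padding, a transposed convolution multiplies each length by its stride. The helper below is a hypothetical sketch of the usual transposed-convolution length rule (no dilation), with 'valid' included for comparison:

```python
def conv_transpose_length(length, stride, kernel, padding):
    """Output length of a transposed convolution along one axis (no output_padding, no dilation)."""
    if padding == "same":
        return length * stride
    if padding == "valid":
        return length * stride + max(kernel - stride, 0)
    raise ValueError(f"unsupported padding: {padding}")

shape = (3, 32, 32)  # (frames, height, width) after Reshape
for _ in range(2):   # two Conv3DTranspose layers with strides=(1, 2, 2) and kernel size 3
    shape = tuple(conv_transpose_length(n, s, 3, "same") for n, s in zip(shape, (1, 2, 2)))
    print(shape)
```

The frame axis keeps its length because its stride is 1, while height and width go 32 -> 64 -> 128.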