Build (pre-trained) CNN+LSTM network with keras functional API

I want to build an LSTM on top of a pre-trained CNN (VGG) to classify a video sequence. The LSTM will be fed the features extracted by the last FC layer of VGG.

The architecture is something like:

[architecture diagram: per-frame VGG feature extractor followed by an LSTM and a softmax classifier]

I wrote the code:

def build_LSTM_CNN_net():
    from keras.applications.vgg16 import VGG16
    from keras.models import Model
    from keras.layers import Dense, Input, Flatten
    from keras.layers.pooling import GlobalAveragePooling2D, GlobalAveragePooling1D
    from keras.layers.recurrent import LSTM
    from keras.layers.wrappers import TimeDistributed
    from keras.optimizers import Nadam

    num_classes = 5
    frames = Input(shape=(5, 224, 224, 3))
    base_in = Input(shape=(224, 224, 3))

    base_model = VGG16(weights='imagenet',
                       include_top=False,
                       input_shape=(224, 224, 3))

    x = Flatten()(base_model.output)
    x = Dense(128, activation='relu')(x)
    x = TimeDistributed(Flatten())(x)  # <- the ValueError is raised here
    x = LSTM(units=256, return_sequences=False, dropout=0.2)(x)
    x = Dense(num_classes, activation='softmax')(x)

lstm_cnn = build_LSTM_CNN_net()
keras.utils.plot_model(lstm_cnn, "lstm_cnn.png", show_shapes=True)

But I got the error:

ValueError: `TimeDistributed` Layer should be passed an `input_shape ` with at least 3 dimensions, received: [None, 128]

Why is this happening, and how can I fix it?

asked Nov 24 '25 by okuoub

1 Answer

Here is the correct way to build a model that classifies video sequences. Note that I wrap a model instance in TimeDistributed: the model is first built to extract features from each frame individually, and the second part then handles the frame sequence. This is also why your code fails: the output of Dense(128) is a 2D tensor of shape [None, 128], so TimeDistributed has no time axis left to distribute over; instead, the whole per-frame CNN has to be wrapped in TimeDistributed and applied to the 5D video input.
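A minimal sketch (assumed, not part of the original answer) that makes the shape problem in the posted code concrete:

from keras.applications.vgg16 import VGG16
from keras.layers import Flatten, Dense

base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = Flatten()(base_model.output)      # shape (None, 25088) = 7*7*512, already 2D
x = Dense(128, activation='relu')(x)  # shape (None, 128), still 2D: no time axis
# TimeDistributed(Flatten())(x) would now raise the ValueError, because
# TimeDistributed expects at least 3 dimensions, e.g. (batch, timesteps, features).

The corrected architecture below avoids this by applying TimeDistributed to the entire frame-level CNN.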

from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.layers import Input, Dense, LSTM, GlobalAveragePooling2D, TimeDistributed

frames, channels, rows, columns = 5, 3, 224, 224

# input: a sequence of 5 RGB frames of size 224x224
video = Input(shape=(frames,
                     rows,
                     columns,
                     channels))
cnn_base = VGG16(input_shape=(rows,
                              columns,
                              channels),
                 weights="imagenet",
                 include_top=False)
cnn_base.trainable = False

# frame-level feature extractor: frozen VGG16 + global average pooling
cnn_out = GlobalAveragePooling2D()(cnn_base.output)
cnn = Model(cnn_base.input, cnn_out)

# apply the feature extractor to every frame, then model the sequence
encoded_frames = TimeDistributed(cnn)(video)
encoded_sequence = LSTM(256)(encoded_frames)
hidden_layer = Dense(1024, activation="relu")(encoded_sequence)
outputs = Dense(10, activation="softmax")(hidden_layer)

model = Model(video, outputs)
model.summary()
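Not part of the original answer, but a quick smoke test with random data (assuming integer class labels and sparse categorical cross-entropy) can confirm that the model accepts the 5D video input:

import numpy as np

# two random "videos" of 5 RGB frames each, plus random integer labels in [0, 10)
dummy_videos = np.random.rand(2, 5, 224, 224, 3).astype("float32")
dummy_labels = np.random.randint(0, 10, size=(2,))

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(dummy_videos, dummy_labels, epochs=1, batch_size=2)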

If you want to use the VGG 1x4096 embedding representation, you can simply do:

# same imports as in the previous snippet
frames, channels, rows, columns = 5, 3, 224, 224

video = Input(shape=(frames,
                     rows,
                     columns,
                     channels))
cnn_base = VGG16(input_shape=(rows,
                              columns,
                              channels),
                 weights="imagenet",
                 include_top=True) #<=== include_top=True
cnn_base.trainable = False

cnn = Model(cnn_base.input, cnn_base.layers[-3].output) # -3 is the 4096 layer
encoded_frames = TimeDistributed(cnn)(video)
encoded_sequence = LSTM(256)(encoded_frames)
hidden_layer = Dense(1024, activation="relu")(encoded_sequence)
outputs = Dense(10, activation="softmax")(hidden_layer)

model = Model(video, outputs)
model.summary()
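As an assumed sanity check (not in the original answer), the model shapes can be inspected to confirm that the truncated VGG16 emits the 4096-dim FC features for each frame:

print(cnn.output_shape)    # expected: (None, 4096)  -- per-frame feature vector
print(model.input_shape)   # expected: (None, 5, 224, 224, 3)
print(model.output_shape)  # expected: (None, 10)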
answered Nov 27 '25 by Marco Cerliani


