 

Strange loss curve while training LSTM with Keras

I'm trying to train an LSTM for a binary classification problem. When I plot the loss curve after training, there are strange spikes in it. Here are some examples:

[Two example loss plots showing the spikes; images not preserved]

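Here is the basic code:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation
from keras.callbacks import EarlyStopping, ModelCheckpoint

# X_train: 2-D array of shape (samples, columnCount); y_train: 0/1 labels
columnCount = X_train.shape[1]

model = Sequential()
model.add(LSTM(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(LSTM(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# LSTM layers expect 3-D input (samples, timesteps, features):
# each column becomes one timestep with a single feature
new_train = X_train[..., np.newaxis]

# epochs replaces the deprecated nb_epoch argument
history = model.fit(new_train, y_train, epochs=500, batch_size=100,
                    callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=2, verbose=0, mode='auto'),
                               ModelCheckpoint(filepath="model.h5", verbose=0, save_best_only=True)],
                    validation_split=0.1)

# list all data in history
print(history.history.keys())
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# val_loss comes from the validation_split data, not a separate test set
plt.legend(['train', 'validation'], loc='upper left')
plt.show()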

I don't understand why these spikes occur. Any ideas?
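To locate the spikes numerically rather than by eye, a small helper can scan the per-epoch losses recorded in history.history. This is a minimal sketch: `find_spikes` is a hypothetical helper, and the 1.5x jump threshold and the loss values below are made up for illustration.

```python
def find_spikes(losses, ratio=1.5):
    """Return the 0-based epoch indices where the loss exceeds `ratio` times
    the previous epoch's loss (`ratio` is an arbitrary, hypothetical threshold)."""
    return [
        epoch
        for epoch in range(1, len(losses))
        if losses[epoch] > ratio * losses[epoch - 1]
    ]


# Made-up loss values for illustration; in practice pass
# history.history['loss'] or history.history['val_loss'].
example_losses = [0.69, 0.52, 0.41, 0.88, 0.37, 0.33]
print(find_spikes(example_losses))  # → [3]
```

Note that Keras prints epochs starting from 1, so index 3 here corresponds to "Epoch 4" in the training log.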

nabroyan asked Mar 08 '23 16:03


1 Answer

There are several possible reasons why this happens:

  1. Your parameter trajectory changed its basin of attraction: the optimization left a stable trajectory and switched to another one. This was probably caused by randomness in training, e.g. mini-batch sampling or dropout.

  2. LSTM instability: LSTMs are believed to be very unstable to train, and it has been reported that they often take a long time to stabilize.
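To make the first point concrete, here is a sketch of mini-batch sampling noise. As an assumption, it uses a synthetic logistic-regression problem in place of the LSTM (whose gradients are harder to inspect) and measures how far a random mini-batch gradient lands from the full-dataset gradient at fixed weights. The data sizes, the seed, and the helpers `logistic_grad` and `grad_noise` are hypothetical, and dropout noise is not modeled.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Synthetic binary-classification data (assumption: 10,000 samples, 20 features).
n_samples, n_features = 10_000, 20
X = rng.normal(size=(n_samples, n_features))
true_w = rng.normal(size=n_features)
y = (X @ true_w + rng.normal(size=n_samples) > 0).astype(float)

# Fixed model parameters at which the gradient noise is measured.
w = np.zeros(n_features)


def logistic_grad(Xb, yb, w):
    """Gradient of the mean binary cross-entropy of a logistic model w.r.t. w."""
    p = 1.0 / (1.0 + np.exp(-(Xb @ w)))
    return Xb.T @ (p - yb) / len(yb)


full_grad = logistic_grad(X, y, w)


def grad_noise(batch_size, n_trials=500):
    """Average distance between a random mini-batch gradient and the full gradient."""
    errors = []
    for _ in range(n_trials):
        idx = rng.choice(n_samples, size=batch_size, replace=False)
        errors.append(np.linalg.norm(logistic_grad(X[idx], y[idx], w) - full_grad))
    return float(np.mean(errors))


for b in (10, 100, 1000):
    print(f"batch_size={b:5d}  mean gradient error={grad_noise(b):.4f}")
```

Every update therefore follows a noisy estimate of the true gradient, and that noise shrinks roughly like 1/sqrt(batch_size). Occasionally such a noisy step can push the parameters into a different region of the loss surface.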

Based on recent research (e.g. from here), I would recommend decreasing the batch size and training for more epochs. I would also check whether the network topology is too complex (or too simple) for the amount of patterns it needs to learn. Finally, I would try switching to either GRU or SimpleRNN.
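One mechanical consequence of decreasing the batch size is that each epoch performs more (and individually noisier) parameter updates. The sketch below does that arithmetic. The helper `updates_per_epoch` and the 10,000-sample dataset size are hypothetical. It mirrors `validation_split` by holding out the last fraction of the samples, which is how Keras documents that option.

```python
import math


def updates_per_epoch(n_samples, batch_size, validation_split=0.1):
    """Approximate optimizer steps per epoch when the last `validation_split`
    fraction of the samples is held out for validation; a final, partial
    batch still counts as one step."""
    n_train = int(n_samples * (1 - validation_split))
    return math.ceil(n_train / batch_size)


# Hypothetical dataset size of 10,000 samples.
for batch_size in (100, 32, 16):
    print(batch_size, updates_per_epoch(10_000, batch_size))
```

With the question's batch_size=100 this gives 90 updates per epoch. Dropping to 32 gives 282, so training "for more epochs" at a smaller batch size means substantially more optimizer steps overall.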

Marcin Możejko answered Apr 01 '23 02:04