Why is LayerNormBasicLSTMCell much slower and less accurate than LSTMCell?

I recently found that LayerNormBasicLSTMCell is a version of LSTM with Layer Normalization and dropout built in, so I replaced the LSTMCell in my original code with LayerNormBasicLSTMCell. Not only did this change reduce the test accuracy from ~96% to ~92%, it also made training take much longer (~33 hours, versus ~6 hours originally). All parameters are the same: number of epochs (10), number of stacked layers (3), hidden vector size (250), dropout keep prob (0.5), ... The hardware is also the same.

My question is: What did I do wrong here?

My original model (using LSTMCell):

# Batch normalization of the raw input
tf_b_VCCs_AMs_BN1 = tf.layers.batch_normalization(
    tf_b_VCCs_AMs, # the input vector, size [#batches, #time_steps, 2]
    axis=-1, # axis that should be normalized 
    training=Flg_training, # Flg_training = True during training, and False during test
    trainable=True,
    name="Inputs_BN"
    )

# Bidirectional dynamic stacked LSTM

##### The part I changed in the new model (start) #####
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.nn.rnn_cell.LSTMCell(num_units=250, state_is_tuple=True)
    dropcells.append(tf.nn.rnn_cell.DropoutWrapper(cell=cell_iiLyr, output_keep_prob=0.5))
##### The part I changed in the new model (end) #####

MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=dropcells, state_is_tuple=True)

outputs, states  = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=MultiLyr_cell, 
    cell_bw=MultiLyr_cell,
    dtype=tf.float32,
    sequence_length=tf_b_lens, # the actual lengths of the input sequences (tf_b_VCCs_AMs_BN1)
    inputs=tf_b_VCCs_AMs_BN1,
    scope = "BiLSTM"
    )

My new model (using LayerNormBasicLSTMCell):

...
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units=250,
        forget_bias=1.0,
        activation=tf.tanh,
        layer_norm=True,
        norm_gain=1.0,
        norm_shift=0.0,
        dropout_keep_prob=0.5
        )
    dropcells.append(cell_iiLyr)
...
asked Jul 17 '17 by Maosi Chen

2 Answers

Perhaps dropout_keep_prob should be given a placeholder instead of a constant value. Try feeding 0.5 during training and 1.0 at inference. Just a guess.
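A minimal sketch of that idea, assuming TF 1.x (the keep_prob placeholder name and the train_op/accuracy_op names are mine, not from the question):

import tensorflow as tf

# Scalar placeholder that defaults to 1.0 (no dropout) when nothing is fed.
keep_prob = tf.placeholder_with_default(1.0, shape=[], name="keep_prob")

cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
    num_units=250,
    layer_norm=True,
    dropout_keep_prob=keep_prob  # a tensor instead of a hard-coded 0.5
    )

# Training step:   sess.run(train_op, feed_dict={keep_prob: 0.5})
# Inference step:  sess.run(accuracy_op)  # default of 1.0 disables dropout
# (train_op / accuracy_op stand in for your own ops.)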

answered Dec 16 '22 by carusyte


About the training time: I came across this blog post: http://olavnymoen.com/2016/07/07/rnn-batch-normalization. See the last figure: the batch-normalized LSTM was more than three times slower than the vanilla LSTM. The author argues that the cause is the batch statistics computation.
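For context, a rough sketch of the per-step statistics computation that normalization adds on top of a plain LSTM step (the function name and eps value are illustrative assumptions, not taken from the cell's implementation):

import tensorflow as tf

# Normalize a gate's pre-activations over the hidden dimension.
# These statistics are recomputed at every time step, for every gate,
# which is extra work a vanilla LSTMCell does not do.
def layer_norm(x, gain=1.0, shift=0.0, eps=1e-5):
    mean, var = tf.nn.moments(x, axes=[-1], keep_dims=True)
    return gain * (x - mean) / tf.sqrt(var + eps) + shift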

About the accuracy: I have no idea.

answered Dec 16 '22 by Fariborz Ghavamian