
How to implement Tensorflow batch normalization in LSTM

My current LSTM network looks like this.

rnn_cell = tf.contrib.rnn.BasicRNNCell(num_units=CELL_SIZE)
init_s = rnn_cell.zero_state(batch_size=1, dtype=tf.float32)  # very first hidden state
outputs, final_s = tf.nn.dynamic_rnn(
    rnn_cell,              # cell you have chosen
    tf_x,                  # input
    initial_state=init_s,  # the initial hidden state
    time_major=False,      # False: (batch, time step, input); True: (time step, batch, input)
)

# reshape 3D output to 2D for fully connected layer
outs2D = tf.reshape(outputs, [-1, CELL_SIZE])
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)

# reshape back to 3D
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])

Usually, I apply tf.layers.batch_normalization as batch normalization. But I am not sure if this works in an LSTM network.

b1 = tf.layers.batch_normalization(outputs, momentum=0.4, training=True)
d1 = tf.layers.dropout(b1, rate=0.4, training=True)

# reshape 3D output to 2D for fully connected layer
outs2D = tf.reshape(d1, [-1, CELL_SIZE])                       
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)

# reshape back to 3D
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])
BenjiBB asked Oct 24 '17

People also ask

Does batch normalization work with LSTM?

The batch-normalized LSTM is able to converge faster relative to a baseline LSTM. Batch-normalized LSTM also shows improved generalization on permuted sequential MNIST, a task that requires preserving long-term memory information.

What is batch normalization in LSTM?

Batch normalization applied to RNNs is similar to batch normalization applied to CNNs: you compute the statistics in such a way that the recurrent/convolutional properties of the layer still hold after BN is applied.
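To make the idea concrete, here is a minimal NumPy sketch (an illustrative assumption, not any library's actual implementation) of how batch-norm statistics can be computed for a recurrent activation tensor so the recurrence-friendly property holds: the per-feature mean and variance are shared across all time steps, analogous to sharing statistics across spatial positions in a CNN.

```python
import numpy as np

def rnn_batch_norm(x, eps=1e-5):
    """Normalize a (batch, time, features) tensor per feature.

    Statistics are computed over the batch AND time axes, so every
    time step is normalized with the same per-feature mean/variance.
    """
    mean = x.mean(axis=(0, 1), keepdims=True)  # shape (1, 1, features)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(8, 10, 4)  # (batch, time step, features)
y = rnn_batch_norm(x)
```

After normalization, each feature has approximately zero mean and unit variance across the batch and time axes combined.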

What is batch normalization in Tensorflow?

Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1. Importantly, batch normalization works differently during training and during inference.

How is batch normalization implemented?

Batch normalization can be implemented during training by calculating the mean and standard deviation of each input variable to a layer per mini-batch and using these statistics to perform the standardization.
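The training-time computation described above can be sketched in a few lines of NumPy (a hypothetical minimal version; real implementations also track running statistics for inference):

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Batch normalization for a (batch, features) mini-batch at training time.

    Each feature is standardized with the mini-batch mean and standard
    deviation, then scaled and shifted by the learnable gamma and beta.
    """
    mu = x.mean(axis=0)               # per-feature mean over the mini-batch
    var = x.var(axis=0)               # per-feature variance over the mini-batch
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x = np.random.randn(32, 4)
out = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
```

With gamma = 1 and beta = 0 this reduces to plain per-feature standardization of the mini-batch.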


1 Answer

If you want to use batch norm for an RNN (LSTM or GRU), you can check out this implementation, or read the full description in the blog post.

However, layer normalization has advantages over batch norm for sequence data. Specifically, "the effect of batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to recurrent networks" (from the paper Ba, et al., Layer Normalization).

Layer normalization normalizes the summed inputs within each layer. You can check out this implementation of layer normalization for a GRU cell.
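The key difference can be sketched in NumPy (an illustrative assumption, not the GRU-cell implementation referenced above): layer norm normalizes over the feature axis within each individual sample, so it does not depend on the mini-batch size at all, which is why it suits recurrent networks.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Layer normalization: normalize over the last (feature) axis per sample.

    Unlike batch norm, the statistics are computed independently for each
    sample (and each time step), so the result is identical for any batch size.
    """
    mu = x.mean(axis=-1, keepdims=True)   # per-sample mean over features
    var = x.var(axis=-1, keepdims=True)   # per-sample variance over features
    return (x - mu) / np.sqrt(var + eps)

x = np.random.randn(8, 10, 4)  # (batch, time step, features)
y = layer_norm(x)
```

Because each (batch, time step) slice is normalized on its own, even a batch of size 1 at test time is handled consistently, with no running statistics to maintain.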

ngovanmao answered Oct 02 '22