Tensorflow dynamic_rnn parameters meaning

Tags:

tensorflow

I'm struggling to understand the cryptic RNN docs. Any help with the following will be greatly appreciated.

tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)

I'm struggling to understand how these parameters relate to the mathematical LSTM equations and the RNN definition. Where is the cell unroll size? Is it defined by the 'max_time' dimension of the inputs? Is batch_size only a convenience for splitting long data, or is it related to minibatch SGD? Is the output state passed across batches?

Asked by Anton on Jan 27 '17

1 Answer

tf.nn.dynamic_rnn takes in a batch (in the minibatch sense) of unrelated sequences.

  • cell is the actual cell that you want to use (LSTM, GRU, ...)
  • inputs has a shape of batch_size x max_time x input_size (with the default time_major=False), in which max_time is the number of steps in the longest sequence (but all sequences could be of the same length)
  • sequence_length is a vector of size batch_size in which each element gives the length of each sequence in the batch (leave it as the default if all your sequences are the same length). The unroll size itself is given by the max_time dimension of inputs; sequence_length tells dynamic_rnn where each individual sequence ends, so that past that point the output is zeroed and the state is simply copied through. A minimal sketch follows this list.
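
To make the shapes concrete, here is a minimal sketch (TF 1.x; the sizes and the choice of LSTMCell are made-up assumptions, and depending on your TF version the cell classes live in tf.nn.rnn_cell or tf.contrib.rnn):

import tensorflow as tf

batch_size, max_time, input_size, num_units = 4, 10, 8, 32  # made-up sizes

inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
seq_len = tf.placeholder(tf.int32, [batch_size])  # true length of each sequence

cell = tf.nn.rnn_cell.LSTMCell(num_units)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs,
                                         sequence_length=seq_len,
                                         dtype=tf.float32)
# outputs: [batch_size, max_time, num_units], zero-padded past each
# sequence's actual length
# final_state: the state after the last valid step of each sequence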

Hidden state handling

The usual way of handling the hidden state is to define an initial state tensor before the dynamic_rnn call, for instance:

hidden_state_in = cell.zero_state(batch_size, tf.float32) 
output, hidden_state_out = tf.nn.dynamic_rnn(cell, 
                                             inputs,
                                             initial_state=hidden_state_in,
                                             ...)

In the above snippet, both hidden_state_in and hidden_state_out have the same shape [batch_size, ...] (the actual shape depends on the type of cell you use but the important thing is that the first dimension is the batch size).
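
For instance, assuming cell is an LSTMCell as in the sketch above (other cell types have different state structures):

cell = tf.nn.rnn_cell.LSTMCell(num_units)
hidden_state_in = cell.zero_state(batch_size, tf.float32)
# for an LSTMCell (with the default state_is_tuple=True) this is an
# LSTMStateTuple(c, h) with
#   c: cell state,   shape [batch_size, num_units]
#   h: hidden state, shape [batch_size, num_units]
# a GRUCell would instead give a single [batch_size, num_units] tensor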

This way, dynamic_rnn has an initial hidden state for each sequence. It passes the hidden state from time step to time step for each sequence in inputs on its own, and hidden_state_out contains the final output state for each sequence in the batch. No hidden state is passed between sequences of the same batch, only between time steps of the same sequence; so, to answer the question above, the output state is not carried across batches unless you feed it back yourself, as described below.

When do I need to feed back the hidden state manually?

Usually, when you're training, every batch is unrelated to the previous one, so you don't have to feed back the hidden state between session.run(output) calls.

However, if you're testing and you need the output at each time step (i.e. you have to do a session.run() at every time step), you'll want to evaluate and feed back the output hidden state, using something like this:

# fetch into a new name (output_val) so the graph tensors `output` and
# `hidden_state_out` are not overwritten by numpy results between steps
output_val, hidden_state = sess.run([output, hidden_state_out],
                                    feed_dict={hidden_state_in: hidden_state})

Otherwise, TensorFlow will just use the default cell.zero_state(batch_size, tf.float32) at each run, which amounts to reinitialising the hidden state at every time step.
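
Put together, a step-by-step inference loop might look like the sketch below (assumptions: the graph was built with max_time = 1 so that each session.run processes one step, and my_input_steps is a hypothetical iterable of [batch_size, 1, input_size] arrays):

hidden_state = sess.run(hidden_state_in)  # numpy zeros from cell.zero_state
for step_input in my_input_steps:
    output_val, hidden_state = sess.run(
        [output, hidden_state_out],
        feed_dict={inputs: step_input,
                   hidden_state_in: hidden_state})

Note that hidden_state_in doesn't need to be a placeholder for this to work: TF 1.x lets you feed any tensor (or nested tuple of tensors, such as an LSTMStateTuple) through feed_dict.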

Answered by Florentin Hennecker on Oct 26 '22