Could someone please clarify whether the initial state of the RNN in TF is reset for each subsequent mini-batch, or whether the last state of the previous mini-batch is used, as mentioned in Ilya Sutskever et al., ICLR 2015?
The default approach to initializing the state of an RNN is to use a zero state. This often works well, particularly for sequence-to-sequence tasks like language modeling, where only a small proportion of the outputs is significantly affected by the initial state.
Before we get down to business, note that the input to tf.nn.dynamic_rnn() needs to have 3 dimensions. With the default time_major=False, these are batch size, number of time steps and number of features.
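To make the expected layout concrete, here is a small NumPy sketch. NumPy only stands in for the arrays you would feed, and the sizes are arbitrary. It shows the batch-major shape dynamic_rnn() expects by default, the time-major shape used with time_major=True, and how a single 2-D time step can be given an explicit time axis.

```python
import numpy as np

# Arbitrary example sizes, chosen only for illustration.
batch_size, num_steps, num_features = 4, 5, 3

# Batch-major input: what tf.nn.dynamic_rnn() expects with its default time_major=False.
batch_major = np.zeros((batch_size, num_steps, num_features), dtype=np.float32)

# Time-major input: the layout expected when time_major=True.
time_major = np.transpose(batch_major, (1, 0, 2))

# A single time step of shape (batch_size, num_features) needs an explicit time axis.
single_step = np.zeros((batch_size, num_features), dtype=np.float32)
one_step_sequence = single_step[:, np.newaxis, :]

print(batch_major.shape)        # (4, 5, 3)
print(time_major.shape)         # (5, 4, 3)
print(one_step_sequence.shape)  # (4, 1, 3)
```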
The tf.nn.dynamic_rnn() or tf.nn.rnn() operations allow you to specify the initial state of the RNN using the initial_state parameter. If you don't specify this parameter (in which case you have to pass dtype instead), the hidden states are initialized to zero vectors at the beginning of each training batch, i.e. the last state of the previous mini-batch is not carried over.
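The difference between the two behaviours can be illustrated without TensorFlow. The sketch below is a plain NumPy tanh RNN with made-up fixed weights; W_x, W_h and run_rnn are hypothetical names, not TensorFlow APIs. It splits one sequence into two consecutive mini-batches and compares restarting each batch from zeros with feeding the last state of the first batch into the second.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, num_units = 3, 4
# Hypothetical fixed weights for a plain tanh RNN cell (no training involved).
W_x = rng.normal(size=(num_features, num_units))
W_h = rng.normal(size=(num_units, num_units))


def run_rnn(inputs, initial_state):
    """Unroll h_t = tanh(x_t W_x + h_{t-1} W_h) over inputs of shape (batch, steps, features).

    Returns the outputs, shape (batch, steps, units), and the final state, shape (batch, units).
    """
    h = initial_state
    outputs = []
    for t in range(inputs.shape[1]):
        h = np.tanh(inputs[:, t, :] @ W_x + h @ W_h)
        outputs.append(h)
    return np.stack(outputs, axis=1), h


batch_size = 2
sequence = rng.normal(size=(batch_size, 10, num_features))
zero_state = np.zeros((batch_size, num_units))

# Reference: unroll the whole 10-step sequence in one go.
full_outputs, _ = run_rnn(sequence, zero_state)

# Split it into two consecutive "mini-batches" of 5 steps each.
first, second = sequence[:, :5, :], sequence[:, 5:, :]

# (a) Default behaviour without initial_state: every batch starts from zeros.
out_a1, _ = run_rnn(first, zero_state)
out_a2, _ = run_rnn(second, zero_state)
reset = np.concatenate([out_a1, out_a2], axis=1)

# (b) Carry the final state of the first batch into the second.
out_b1, last_state = run_rnn(first, zero_state)
out_b2, _ = run_rnn(second, last_state)
carried = np.concatenate([out_b1, out_b2], axis=1)

print(np.allclose(carried, full_outputs))  # True
print(np.allclose(reset, full_outputs))    # False
```

Carrying the state reproduces the forward pass of the unsplit sequence. Gradients still only flow within each batch, so training this way is truncated backpropagation through time.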
In TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between multiple session runs. Just make sure to mark them as non-trainable, because optimizers tune all trainable variables by default.
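```python
import tensorflow as tf  # TensorFlow 1.x API

# batch_size, max_length and frame_size are assumed to be defined elsewhere.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)
# Non-trainable variable that holds the RNN state between session runs.
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)
# Store the final state whenever output is computed, so the next run starts from it.
with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(output, {data: ...})
```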
I haven't tested this code, but it should point you in the right direction. There is also a tf.nn.state_saving_rnn() to which you can provide a state saver object, but I haven't used it yet.
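Carrying the state over only makes sense if row i of each mini-batch continues row i of the previous one. A common way to get that layout from one long stream (for example text for language modeling) is to cut it into batch_size equally long lanes and walk through them num_steps at a time. This is a minimal NumPy sketch; contiguous_batches is a hypothetical helper, and for dynamic_rnn() you would still add a feature axis, for example via an embedding.

```python
import numpy as np


def contiguous_batches(tokens, batch_size, num_steps):
    """Yield (batch_size, num_steps) windows such that row i of each batch
    continues row i of the previous batch.

    The stream is cut into batch_size equally long lanes (leftover items are
    dropped); each yielded batch holds the next num_steps items of every lane.
    """
    tokens = np.asarray(tokens)
    lane_length = len(tokens) // batch_size
    lanes = tokens[: batch_size * lane_length].reshape(batch_size, lane_length)
    for start in range(0, lane_length - num_steps + 1, num_steps):
        yield lanes[:, start : start + num_steps]


batches = list(contiguous_batches(np.arange(24), batch_size=2, num_steps=4))
for b in batches:
    print(b.tolist())
# [[0, 1, 2, 3], [12, 13, 14, 15]]
# [[4, 5, 6, 7], [16, 17, 18, 19]]
# [[8, 9, 10, 11], [20, 21, 22, 23]]
```

If the batches were shuffled instead, the carried state would come from an unrelated part of the stream, and resetting to zeros would be the better choice.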
In addition to danijar's answer, here is the code for an LSTM, whose state is a tuple (state_is_tuple=True). It also supports multiple layers.
We define two functions: one that creates the state variables with an initial zero state, and one that returns an operation we can pass to session.run in order to update the state variables with the LSTM's last hidden state.
```python
import tensorflow as tf  # TensorFlow 1.x with tf.contrib


def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # tf.tuple groups the assignments so that fetching its result runs all of them.
    # The returned values themselves should not be used.
    return tf.tuple(update_ops)
```
Similar to danijar's answer, we can use these functions to update the LSTM's state after each batch:
```python
# batch_size, max_length, frame_size and num_layers are assumed to be defined elsewhere.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
# LSTM cells are required here: get_state_variables() unpacks each layer's
# state into (c, h), which a GRUCell's single state tensor does not provide.
cells = [tf.contrib.rnn.LSTMCell(256, state_is_tuple=True) for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)
# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)
# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
```
The main difference is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two tensors (the cell state c and the hidden state h) instead of a single tensor. With multiple layers, the overall state becomes a tuple of LSTMStateTuples, one per layer.
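To visualise this nesting, here is a NumPy stand-in with no TensorFlow involved. LSTMState is a namedtuple with the same (c, h) fields as tf.contrib.rnn.LSTMStateTuple, and zero_state and update_state are hypothetical helpers that loosely mirror get_state_variables() and get_state_update_op(). With 3 layers the state is a tuple of 3 pairs, and one update touches 2 * 3 = 6 arrays, the same number of assign ops the TensorFlow version creates.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.contrib.rnn.LSTMStateTuple, a namedtuple with fields (c, h).
LSTMState = namedtuple("LSTMState", ("c", "h"))


def zero_state(num_layers, batch_size, num_units):
    """Return a tuple with one all-zero LSTMState per layer."""
    return tuple(
        LSTMState(c=np.zeros((batch_size, num_units)), h=np.zeros((batch_size, num_units)))
        for _ in range(num_layers)
    )


def update_state(state_store, new_states):
    """Copy new_states into the stored arrays in place, one copy for c and one
    for h on every layer. Returns the number of arrays updated."""
    updated = 0
    for stored, new in zip(state_store, new_states):
        stored.c[...] = new.c
        stored.h[...] = new.h
        updated += 2
    return updated


states = zero_state(num_layers=3, batch_size=2, num_units=4)
# Made-up "final states" of a batch: layer i gets c = i and h = -i everywhere.
new_states = tuple(
    LSTMState(c=np.full((2, 4), float(i)), h=np.full((2, 4), -float(i))) for i in range(3)
)
n_updated = update_state(states, new_states)
print(len(states), n_updated)                # 3 6
print(states[2].c[0, 0], states[2].h[0, 0])  # 2.0 -2.0
```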