
Is RNN initial state reset for subsequent mini-batches?

Could someone please clarify whether the initial state of the RNN in TF is reset for subsequent mini-batches, or whether the last state of the previous mini-batch is used, as mentioned in Ilya Sutskever et al., ICLR 2015?

VM_AI asked Jul 18 '16 16:07




2 Answers

The tf.nn.dynamic_rnn() and tf.nn.rnn() operations let you specify the initial state of the RNN through the initial_state parameter. If you don't specify this parameter, the hidden state is initialized to zero vectors at the beginning of each training batch.
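
To see what the zero-state default means in practice, here is a minimal pure-Python sketch (not TensorFlow). The scalar cell `rnn_step`, its weights and the input sequence are made up for illustration. The sequence is split into two mini-batches: carrying the final state into the second chunk reproduces the full-sequence result, while resetting to zero does not.

```python
import math

# Hypothetical scalar RNN cell with fixed, made-up weights (illustration only).
W_IN, W_REC = 0.5, 0.8


def rnn_step(state, x):
    """One step of h_t = tanh(W_IN * x_t + W_REC * h_{t-1})."""
    return math.tanh(W_IN * x + W_REC * state)


def run_rnn(inputs, initial_state=0.0):
    """Unroll the cell over `inputs` and return the final state."""
    state = initial_state
    for x in inputs:
        state = rnn_step(state, x)
    return state


sequence = [1.0, -0.5, 0.25, 2.0, -1.0, 0.75]
first_chunk, second_chunk = sequence[:3], sequence[3:]

# Reference: the whole sequence processed in one go.
full = run_rnn(sequence)

# Stateful: the second mini-batch starts from the first one's final state.
carried = run_rnn(second_chunk, initial_state=run_rnn(first_chunk))

# Default behaviour: the second mini-batch starts again from a zero state.
reset = run_rnn(second_chunk)

print("carried state matches full run:", full == carried)
print("reset state matches full run:  ", full == reset)
```

The carried run performs exactly the same arithmetic as the full run, which is why the comparison can use exact equality here.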

In TensorFlow, you can wrap tensors in tf.Variable() to keep their values in the graph between multiple session runs. Just make sure to mark them as non-trainable because the optimizers tune all trainable variables by default.

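import tensorflow as tf  # TensorFlow 1.x API

# batch_size, max_length and frame_size are assumed to be defined elsewhere.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

cell = tf.nn.rnn_cell.GRUCell(256)
# Non-trainable variable that keeps the state between session runs.
state = tf.Variable(cell.zero_state(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

# Store the final state in the variable whenever the output is evaluated.
with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(output, {data: ...})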

I haven't tested this code, but it should point you in the right direction. There is also tf.nn.state_saving_rnn(), to which you can provide a state saver object, but I haven't used it yet.
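
Carrying the state across session runs only makes sense if row i of each mini-batch continues row i of the previous one, because row i of the stored state is fed back to row i of the next batch. A minimal pure-Python sketch of such a layout follows. The helper name `contiguous_batches` and the toy data are assumptions for illustration and are not part of TensorFlow.

```python
def contiguous_batches(tokens, batch_size, num_steps):
    """Split one long sequence into mini-batches whose rows line up.

    The sequence is cut into `batch_size` parallel streams. Batch k holds
    steps [k * num_steps, (k + 1) * num_steps) of every stream, so row i of
    batch k + 1 directly continues row i of batch k.
    """
    stream_len = len(tokens) // batch_size
    streams = [tokens[i * stream_len:(i + 1) * stream_len]
               for i in range(batch_size)]
    batches = []
    # Leftover steps that do not fill a whole batch are dropped.
    for start in range(0, stream_len - num_steps + 1, num_steps):
        batches.append([stream[start:start + num_steps] for stream in streams])
    return batches


batches = contiguous_batches(list(range(12)), batch_size=2, num_steps=3)
for batch in batches:
    print(batch)
```

If the mini-batches are instead shuffled or hold unrelated sequences, the stored state belongs to a different sequence, and starting from the zero state is the safer choice.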

danijar answered Sep 28 '22 02:09


In addition to danijar's answer, here is the code for an LSTM, whose state is a tuple (state_is_tuple=True). It also supports multiple layers.

We define two functions - one for getting the state variables with an initial zero state and one function for returning an operation, which we can pass to session.run in order to update the state variables with the LSTM's last hidden state.

import tensorflow as tf  # TensorFlow 1.x API


def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

Similar to danijar's answer, we can use that to update the LSTM's state after each batch:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
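# LSTM cells are required here: get_state_variables unpacks a (c, h) pair per layer.
cells = [tf.contrib.rnn.LSTMCell(256, state_is_tuple=True) for _ in range(num_layers)]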
cell = tf.contrib.rnn.MultiRNNCell(cells)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

The main difference is that state_is_tuple=True makes the LSTM's state an LSTMStateTuple containing two values (cell state and hidden state) instead of a single one. Using multiple layers then makes the LSTM's state a tuple of LSTMStateTuples, one per layer.
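
The nesting can be pictured without TensorFlow. In the sketch below, a `namedtuple` called `StateTuple` stands in for `LSTMStateTuple`, and plain Python lists stand in for the state variables. Both are assumptions made so that the example runs on its own.

```python
from collections import namedtuple

# Stand-in for tf.contrib.rnn.LSTMStateTuple (illustration only).
StateTuple = namedtuple("StateTuple", ["c", "h"])


def zero_state(num_layers, units):
    """Return one (c, h) pair per layer, mirroring a multi-layer LSTM state."""
    return tuple(StateTuple(c=[0.0] * units, h=[0.0] * units)
                 for _ in range(num_layers))


def update_state(stored, new):
    """Copy the new per-layer values into the stored state in place."""
    for stored_layer, new_layer in zip(stored, new):
        stored_layer.c[:] = new_layer.c
        stored_layer.h[:] = new_layer.h


stored = zero_state(num_layers=2, units=3)
new = (StateTuple(c=[1.0, 2.0, 3.0], h=[0.1, 0.2, 0.3]),
       StateTuple(c=[4.0, 5.0, 6.0], h=[0.4, 0.5, 0.6]))
update_state(stored, new)

print(len(stored))   # number of layers
print(stored[1].h)   # hidden state of the second layer
```

The loop in `update_state` plays the same role as the per-layer assign operations: each layer contributes two updates, one for the cell state and one for the hidden state.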

Kilian Batzner answered Sep 28 '22 04:09