
TensorFlow: getting all states from a RNN

How do you get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.

The first alternative would be to write a loop when building a model that operates directly on RNNCell. However, the number of timesteps is not fixed for me, and depends on the incoming batch.

Some options are to either use a GRU or to write my own RNNCell that concatenates the state to the output. The former choice isn't general enough and the latter sounds too hacky.

Another option is to do something like the answers in this question, getting all the variables from an RNN. However, I'm not sure how to separate the hidden states from other variables in a standard fashion here.

Is there a nice way to get all the hidden states from an RNN while still using the library-provided RNN APIs?

Ankit Vani asked Sep 27 '16 04:09



1 Answer

tf.nn.dynamic_rnn (and tf.nn.static_rnn) has two return values: "outputs" and "state" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)

As you said, "state" is the final state of the RNN, but "outputs" contains the hidden state at every timestep (for dynamic_rnn with the default time_major=False, its shape is [batch_size, max_time, cell.output_size]).
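To make the relationship between "outputs" and "state" concrete, here is a minimal library-free sketch of the unrolling loop for a basic tanh cell. It is not TensorFlow code: the function names, the weights, and the toy batch are all made up for illustration. It only shows that collecting the per-step outputs gives one hidden state per timestep, and that the final state is the last of them.

```python
import math

def basic_rnn_step(x, h, w_x, w_h):
    """One step of a basic tanh cell: h_new = tanh(w_x . x + w_h . h).

    Returns (output, new_state); for this kind of cell they are the same vector.
    """
    h_new = [
        math.tanh(
            sum(wx * xi for wx, xi in zip(w_x[j], x))
            + sum(wh * hi for wh, hi in zip(w_h[j], h))
        )
        for j in range(len(h))
    ]
    return h_new, h_new

def unroll(inputs, h0, w_x, w_h):
    """Run the cell over one sequence and return (outputs, final_state)."""
    outputs, h = [], h0
    for x in inputs:
        out, h = basic_rnn_step(x, h, w_x, w_h)
        outputs.append(out)  # keep the hidden state of every timestep
    return outputs, h

# Hypothetical toy data: 2 sequences, 3 timesteps, 2 input features, 2 hidden units.
W_X = [[0.5, -0.3], [0.1, 0.8]]
W_H = [[0.2, 0.0], [-0.4, 0.6]]
batch = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[0.5, 0.5], [-1.0, 0.2], [0.3, -0.7]],
]

results = [unroll(seq, [0.0, 0.0], W_X, W_H) for seq in batch]
all_outputs = [outs for outs, _ in results]      # [batch_size, max_time, output_size]
final_states = [state for _, state in results]   # [batch_size, state_size]
```

Here all_outputs plays the role of "outputs" and final_states the role of "state"; all_outputs[i][-1] is the same vector as final_states[i].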

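You can use "outputs" as the hidden states of the RNN because, in most library-provided RNNCell implementations, the "output" and the "state" are the same. The exception is LSTMCell, whose state is a (c, h) pair while its output is only h, so "outputs" gives you h for every timestep but not c.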

  • BasicRNNCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347
  • GRUCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441
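For cells where the output and the state differ, one possible approach is the wrapper idea mentioned in the question: have the cell emit its full state as part of its per-step output. The sketch below is library-free and uses a hypothetical two-part cell as a stand-in for an LSTM-style cell (it has no gates and is not a real LSTM); expose_state and unroll are illustrative names, not TensorFlow APIs.

```python
import math

def toy_two_part_cell(x, state):
    """Hypothetical stand-in for an LSTM-style cell.

    The state is a (c, h) pair, but the output is h only.
    """
    c, h = state
    c_new = 0.5 * c + x        # simplified memory update, no gates
    h_new = math.tanh(c_new)
    return h_new, (c_new, h_new)

def expose_state(cell):
    """Wrap a cell so its per-step output also carries the full new state."""
    def wrapped(x, state):
        output, new_state = cell(x, state)
        return (output, new_state), new_state
    return wrapped

def unroll(cell, inputs, initial_state):
    """Run a cell over a sequence and return (outputs, final_state)."""
    outputs, state = [], initial_state
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state

inputs = [1.0, -0.5, 0.25]

# Plain cell: the outputs contain h for each timestep, but c is lost.
plain_outputs, plain_final = unroll(toy_two_part_cell, inputs, (0.0, 0.0))

# Wrapped cell: each output is (h, (c, h)), so every intermediate state survives.
wrapped_outputs, wrapped_final = unroll(expose_state(toy_two_part_cell), inputs, (0.0, 0.0))
all_states = [state for _, state in wrapped_outputs]
```

In TensorFlow the same idea would presumably be a custom RNNCell whose output size accounts for the extra state, which is the option the question calls hacky; for cells like BasicRNNCell and GRUCell it is unnecessary because "outputs" already holds the states.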
Junyeop Lee answered Sep 20 '22 16:09