TensorFlow: Remember LSTM state for next batch (stateful LSTM)

Given a trained LSTM model I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (memory and hidden) states need to be remembered for the next 'batch'. At the very beginning of inference the internal LSTM states init_c, init_h are computed from the input, stored in an LSTMStateTuple object, and passed to the LSTM. During training this state is updated every timestep. For inference, however, I want the state to be kept between batches: the initial states only need to be computed at the very beginning, and after that the LSTM states should be saved after each 'batch' (n = 1).

I found this related StackOverflow question: Tensorflow, best way to save state in RNNs?. However, that approach only works with state_is_tuple=False, a behavior that is soon to be deprecated by TensorFlow (see rnn_cell.py). Keras seems to have a nice wrapper that makes stateful LSTMs possible, but I don't know the best way to achieve the same in TensorFlow. This issue on the TensorFlow GitHub is also related to my question: https://github.com/tensorflow/tensorflow/issues/2838
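For reference, the Keras wrapper alluded to above looks roughly like this (a minimal sketch; feature_dim is an illustrative name, not from this question):

from keras.models import Sequential
from keras.layers import LSTM, Dense

feature_dim = 256  # illustrative input feature size

# stateful=True carries the LSTM state over from batch to batch instead of
# resetting it; a fixed batch size must be given via batch_input_shape.
model = Sequential()
model.add(LSTM(512, batch_input_shape=(1, 1, feature_dim), stateful=True))
model.add(Dense(1))

# model.reset_states() clears the carried state when a new sequence starts.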

Does anyone have good suggestions for building a stateful LSTM model?

inputs  = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")

num_lstm_layers = 2

with tf.variable_scope("LSTM") as scope:

    lstm_cell = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
    self.lstm = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_tuple=True)

    init_c = # compute initial LSTM memory state using contents in placeholder 'inputs'
    init_h = # compute initial LSTM hidden state using contents in placeholder 'inputs'
    self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers

    outputs = []

    for step in range(seq_length):

        if step != 0:
            scope.reuse_variables()

        # CNN features, as input for LSTM
        x_t = # ...

        # LSTM step through time
        output, self.state = self.lstm(x_t, self.state)
        outputs.append(output)
asked Jul 07 '16 by verified.human


2 Answers

I found it easiest to save the whole state for all layers in a single placeholder.

init_state = np.zeros((num_layers, 2, batch_size, state_size))

...

# the second dimension (2) holds the cell state c and the hidden state h per layer
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

Then unstack it and create a tuple of LSTMStateTuples before using the native TensorFlow RNN API.

l = tf.unstack(state_placeholder, axis=0)  # tf.unpack in TensorFlow versions before 1.0
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)]
)

The cell and the tuple state are then passed to the RNN API:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
# note: newer TF 1.x versions require a fresh LSTMCell instance per layer
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)

The state returned by dynamic_rnn can then be fed back in through the placeholder for the next batch.
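For concreteness, a minimal sketch of that feed loop, assuming a Session named sess and a hypothetical iterable batches (outputs, state, x_input_batch and state_placeholder come from the snippets above):

import numpy as np
import tensorflow as tf

# zero state for the very first batch; shape matches state_placeholder
current_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for batch in batches:  # hypothetical iterable of input batches
        batch_outputs, current_state = sess.run(
            [outputs, state],
            feed_dict={x_input_batch: batch,
                       state_placeholder: current_state})
        # sess.run returns the state as a nested tuple of numpy arrays;
        # stack it back into the placeholder's [num_layers, 2, ...] layout
        current_state = np.asarray(current_state)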

answered by user1506145


Tensorflow, best way to save state in RNNs? was actually my original question. The code below is how I use the state tuples.

with tf.variable_scope('decoder') as scope:
    rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
        tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
    ], state_is_tuple=True)

    state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer]
             for sz_outer in rnn_cell.state_size]

    for t in range(TIME_STEPS):
        if t:
            last = y_[t - 1] if TRAINING else y[t - 1]
        else:
            last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))

        y[t] = tf.concat((y[t], last), 1)  # tf.concat(1, values) in TF < 1.0
        y[t], state = rnn_cell(y[t], state)

        scope.reuse_variables()

Rather than using tf.nn.rnn_cell.LSTMStateTuple I just create a list of lists, which works fine. In this example I am not saving the state. However, you could easily make the state out of variables and just use assign to save the values, as sketched below.
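A minimal sketch of that variable-based variant, for a single LSTM layer (the sizes, x_t and lstm_cell are illustrative names, not from the answer; only the variable/assign pattern matters here):

import tensorflow as tf

BATCH_SIZE, INPUT_SIZE, STATE_SIZE = 32, 256, 512  # illustrative sizes
x_t = tf.placeholder(tf.float32, (BATCH_SIZE, INPUT_SIZE))
lstm_cell = tf.nn.rnn_cell.LSTMCell(STATE_SIZE, state_is_tuple=True)

# keep the state in non-trainable variables so it persists across session runs
state_c = tf.Variable(tf.zeros((BATCH_SIZE, STATE_SIZE)), trainable=False)
state_h = tf.Variable(tf.zeros((BATCH_SIZE, STATE_SIZE)), trainable=False)

output, new_state = lstm_cell(x_t, tf.nn.rnn_cell.LSTMStateTuple(state_c, state_h))

# running save_state after each batch writes the new state back into the
# variables, so the next batch resumes where this one left off
save_state = tf.group(state_c.assign(new_state.c),
                      state_h.assign(new_state.h))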

answered by chasep255