I'm training an LSTM cell on batches of sequences that have different lengths. tf.nn.rnn
has the very convenient parameter sequence_length
, but after calling it, I don't know how to pick the output rows corresponding to the last time step of each item in the batch.
My code is basically as follows:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
is a list with the LSTM output at each time step. However, each item in my batch has a different length, so I would like to create a tensor containing the last valid LSTM output for each item in my batch.
If I could use numpy indexing, I would just do something like this:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths - 1, tf.range(batch_size), :]
But it turns out that, for the time being, TensorFlow doesn't support this kind of indexing (I'm aware of the feature request).
So, how could I get these values?
A more acceptable workaround was published by danijar on the feature request page I linked in the question. It doesn't need to evaluate the tensors, which is a big plus.
I got it to work with tensorflow 0.8. Here is the code:
def extract_last_relevant(outputs, length):
    """
    Args:
        outputs: [Tensor(batch_size, output_neurons)]: A list containing the
            output activations of each example in the batch for each time step
            as returned by tensorflow.models.rnn.rnn.
        length: Tensor(batch_size): The used sequence length of each example in
            the batch with all later time steps being zeros. Should be of type
            tf.int32.
    Returns:
        Tensor(batch_size, output_neurons): The last relevant output activation
        for each example in the batch.
    """
    # Pack the per-step outputs into one tensor and make it batch-major:
    # (batch_size, max_length, output_neurons).
    output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
    # Query shape.
    batch_size = tf.shape(output)[0]
    max_length = int(output.get_shape()[1])
    num_neurons = int(output.get_shape()[2])
    # Index into the flattened array as a workaround for the missing
    # advanced-indexing support.
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, num_neurons])
    relevant = tf.gather(flat, index)
    return relevant
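To sanity-check the flattened-index arithmetic, the same trick can be reproduced in plain NumPy. This is a standalone sketch with made-up shapes and values, independent of TensorFlow:

```python
import numpy as np

# Dummy batch: 3 examples, max_length 4, 2 output units.
batch_size, max_length, num_units = 3, 4, 2
outputs = np.arange(batch_size * max_length * num_units,
                    dtype=np.float32).reshape(batch_size, max_length, num_units)
lengths = np.array([2, 4, 1])  # valid time steps per example

# Same arithmetic as the TensorFlow snippet: flatten, then gather.
index = np.arange(batch_size) * max_length + (lengths - 1)
flat = outputs.reshape(-1, num_units)
relevant = flat[index]

# Cross-check against direct per-example indexing.
expected = np.stack([outputs[i, lengths[i] - 1] for i in range(batch_size)])
assert np.array_equal(relevant, expected)
```

Row `i * max_length + (length - 1)` of the flattened array is exactly the last valid step of example `i`, which is why the `tf.gather` over the reshaped tensor picks the right rows.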
It might not be the nicest solution, but you could evaluate your outputs, use NumPy indexing to pick the results, and create a constant tensor from that. It might work as a stopgap until TensorFlow gets this feature. E.g.
all_outputs = np.asarray(session.run(lstm_outputs, feed_dict={'your inputs'}))
last_outputs = all_outputs[sequence_lengths - 1, np.arange(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
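The NumPy indexing step on its own can be illustrated like this (a sketch with dummy data; the shapes assume time-major outputs, i.e. the stacked per-step list that tf.nn.rnn returns):

```python
import numpy as np

# Dummy stand-in for the evaluated outputs: time-major
# array of shape (max_time, batch_size, num_units).
max_time, batch_size, num_units = 5, 3, 2
all_outputs = np.random.rand(max_time, batch_size, num_units)
sequence_lengths = np.array([3, 5, 1])

# Advanced indexing: time step length-1, paired with one
# batch index per example.
last_outputs = all_outputs[sequence_lengths - 1, np.arange(batch_size), :]
```

Note the `- 1`: sequence lengths are counts, so the last valid output of an example with length 3 lives at time index 2.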
If you're only interested in the last valid output, you can retrieve it through the state returned by tf.nn.rnn(),
since for an LSTM the state is always a tuple (c, h), where c is the last cell state and h is the last output. When the state is an LSTMStateTuple,
you can use the following snippet (working in TensorFlow 0.12):
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]  # state.h, the output at the last valid time step