
How to pick the last valid output values from a TensorFlow RNN

I'm training an LSTM cell on batches of sequences that have different lengths. tf.nn.rnn has the very convenient parameter sequence_length, but after calling it, I don't know how to pick the outputs corresponding to the last valid time step of each item in the batch.

My code is basically as follows:

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)

lstm_outputs is a list with the LSTM output at each time step. However, each item in my batch has a different length, so I would like to create a tensor containing the last valid LSTM output for each item in my batch.

If I could use numpy indexing, I would just do something like this:

all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths - 1, tf.range(batch_size), :]

But it turns out that, for the time being, TensorFlow doesn't support this kind of indexing (I'm aware of the feature request).
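For concreteness, here is the same selection done in plain numpy on a toy array (the shapes and lengths below are made up purely for illustration):

import numpy as np

# made-up toy shapes: 4 time steps, batch of 3, 2 LSTM units
all_outputs = np.random.rand(4, 3, 2)              # time-major, like tf.pack(lstm_outputs)
sequence_lengths = np.array([2, 4, 3])             # valid length of each batch item
last_outputs = all_outputs[sequence_lengths - 1, np.arange(3), :]   # shape (3, 2)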

So, how could I get these values?

asked Mar 07 '16 by erickrf


3 Answers

A more acceptable workaround was published by danijar on the feature request page I linked in the question. It doesn't need to evaluate the tensors, which is a big plus.

I got it to work with TensorFlow 0.8. Here is the code:

def extract_last_relevant(outputs, length):
    """
    Args:
        outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
            activations of each item in the batch for each time step, as returned by
            tensorflow.models.rnn.rnn.
        length: Tensor(batch_size): The used sequence length of each example in the
            batch with all later time steps being zeros. Should be of type tf.int32.

    Returns:
        Tensor(batch_size, output_neurons): The last relevant output activation for
            each example in the batch.
    """
    output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
    # Query shape.
    batch_size = tf.shape(output)[0]
    max_length = int(output.get_shape()[1])
    num_neurons = int(output.get_shape()[2])
    # Index into flattened array as a workaround.
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, num_neurons])
    relevant = tf.gather(flat, index)
    return relevant
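For reference, here is a minimal sketch of how this would plug into the code from the question (the variable names are just the question's placeholders, not a tested script):

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32,
                                sequence_length=sequence_lengths)
# sequence_lengths should be an int32 tensor of shape (batch_size,)
last_outputs = extract_last_relevant(lstm_outputs, sequence_lengths)
# last_outputs has shape (batch_size, num_lstm_units)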
answered by erickrf


It may not be the nicest solution, but you could evaluate your outputs and then just use numpy indexing to get the results and create a tensor from that. It might work as a stopgap until TensorFlow gets this feature. e.g.

import numpy as np

all_outputs = np.array(session.run(lstm_outputs, feed_dict={'your inputs'}))  # shape (max_time, batch_size, num_units)
last_outputs = all_outputs[sequence_lengths - 1, np.arange(batch_size), :]    # plain numpy indexing
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
answered by Daniel Slater


If you're only interested in the last valid output, you can retrieve it through the state returned by tf.nn.rnn(), since for an LSTM the state is always a tuple (c, h), where c is the last cell state and h is the last output. When the state is an LSTMStateTuple you can use the following snippet (working in TensorFlow 0.12):

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]
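Equivalently (just a sketch of the same idea, assuming the state really is an LSTMStateTuple), you can use its named fields instead of indexing:

last_output = state.h       # hidden state = output at the last valid time step
last_cell_state = state.c   # cell state at the last valid time step

This works because, when sequence_length is passed, tf.nn.rnn copies the state through unchanged after each sequence ends, so the returned state corresponds to each item's last valid step.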
answered by learningTask