
Get last output of dynamic_rnn in tensorflow?

I am using dynamic_rnn to process MNIST data:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]

Full code: http://pastebin.com/bhf9MgMe

The input to the lstm is (batch_size, sequence_length, input_size)

As a result, the dimensions of output are (batch_size, sequence_length, num_units), where num_units=200, and output_at_T is (batch_size, num_units).

I need to get the last output along the sequence_length dimension. In the code above, this is hardcoded as 27. However, I do not know the sequence_length in advance as it can change from batch to batch in my application.

I tried:

output_at_T = output[:, -1, :]

but it says negative indexing is not implemented yet, and I tried using a placeholder variable as well as a constant (into which I could ideally feed the sequence_length for a particular batch); neither worked.

Any way to implement something like this in tensorflow atm?

applecider asked Apr 23 '16 23:04


3 Answers

Have you noticed that there are two outputs from dynamic_rnn?

  1. Output 1, let's call it h, has the outputs at every time step (i.e. h_1, h_2, etc.),
  2. Output 2, final_state, has two elements for an LSTM cell with a tuple state: the cell state (c), and the last output (h) for each element of the batch (as long as you pass the sequence lengths to dynamic_rnn).

So from:

h, final_state = tf.nn.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )

the last output (hidden state) for each element in the batch is:

final_state.h

Note that this includes the case when the length of the sequence is different for each element of the batch, as we are using the sequence_length argument.
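To illustrate why the final state matches the last valid output, here is a minimal NumPy sketch of a toy tanh RNN. It is not TensorFlow's implementation; the function toy_dynamic_rnn, the weights, and the shapes are hypothetical. It only mimics the documented behaviour of sequence_length: once a sequence has ended, its state is copied through and its outputs are zeroed.

```python
import numpy as np

def toy_dynamic_rnn(x, seq_len, w_in, w_rec):
    """Toy tanh RNN (hypothetical helper) mimicking sequence_length handling.

    x: (batch, time, input_size); seq_len: (batch,) valid lengths.
    Returns (outputs, final_state).
    """
    batch, time, _ = x.shape
    units = w_rec.shape[0]
    state = np.zeros((batch, units))
    outputs = np.zeros((batch, time, units))
    for t in range(time):
        new_state = np.tanh(x[:, t, :] @ w_in + state @ w_rec)
        active = (t < seq_len)[:, None]             # (batch, 1) mask of unfinished sequences
        state = np.where(active, new_state, state)  # finished sequences keep their state
        outputs[:, t, :] = np.where(active, new_state, 0.0)  # zero outputs past the end
    return outputs, state

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 4))      # batch=3, time=5, input_size=4
seq_len = np.array([5, 2, 3])       # a different valid length per batch element
w_in = rng.normal(size=(4, 6))
w_rec = rng.normal(size=(6, 6))

outputs, final_state = toy_dynamic_rnn(x, seq_len, w_in, w_rec)
# Output at the last valid step of each sequence.
last_valid = outputs[np.arange(3), seq_len - 1]
print(np.allclose(final_state, last_valid))
```

In this sketch final_state plays the role of final_state.h: it equals the output at step seq_len - 1 for every batch element, regardless of padding.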

Escachator answered Oct 24 '22 21:10


This is what gather_nd is for!

import tensorflow as tf

def extract_axis_1(data, ind):
    """
    Get specified elements along axis 1 of a tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    # One row index per batch element: [0, 1, ..., batch_size - 1]
    batch_range = tf.range(tf.shape(data)[0])
    # Pair each batch index with its index along axis 1: shape [batch_size, 2]
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

In your case (assuming sequence_length is a 1-D tensor with the length of each axis 0 element):

output = extract_axis_1(output, sequence_length - 1)

Now output is a tensor of shape [batch_size, num_units].
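To make the index construction concrete, here is a small NumPy analogue with made-up numbers. It is only a sketch of the same logic, not TensorFlow code: NumPy integer-array indexing stands in for tf.gather_nd, and the shapes and lengths are assumptions chosen for illustration.

```python
import numpy as np

batch_size, max_time, num_units = 3, 4, 2
# Fake RNN outputs: output[b, t, u] = b * 8 + t * 2 + u
output = np.arange(batch_size * max_time * num_units).reshape(
    batch_size, max_time, num_units)
sequence_length = np.array([4, 1, 3])   # valid length of each batch element

batch_range = np.arange(batch_size)
# Pair each batch index with its last valid time step: [[0, 3], [1, 0], [2, 2]]
indices = np.stack([batch_range, sequence_length - 1], axis=1)

# Equivalent of gather_nd(output, indices): one (num_units,) row per pair
last = output[indices[:, 0], indices[:, 1]]
print(last.shape)   # (3, 2)
print(last)         # rows: [6 7], [8 9], [20 21]
```

Each row of indices selects one (batch, time) position, so the result has one row of num_units values per batch element.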

Alex answered Oct 24 '22 21:10


output[:, -1, :]

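works in TensorFlow 1.x now. Note that it returns the final time step of the padded batch, so it only matches the last valid output when every sequence in the batch has the full length.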
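A tiny NumPy sketch of the difference, using made-up values and assuming the outputs past each sequence's end are zero (as dynamic_rnn produces when sequence_length is passed):

```python
import numpy as np

# Padded outputs for 2 sequences, 3 time steps, 1 unit; zeros mark padding.
output = np.array([[[1.0], [2.0], [3.0]],
                   [[4.0], [0.0], [0.0]]])
sequence_length = np.array([3, 1])

# Final time step, whether it is padding or not.
last_step = output[:, -1, :]
# Last valid time step of each sequence.
last_valid = output[np.arange(2), sequence_length - 1]

print(last_step.tolist())    # [[3.0], [0.0]]
print(last_valid.tolist())   # [[3.0], [4.0]]
```

For the second sequence, negative indexing picks up the zero padding rather than the real last output.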

Philippe Remy answered Oct 24 '22 20:10