I am using dynamic_rnn to process MNIST data:
# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)
# Initial state
istate = lstm.zero_state(batch_size, "float")
# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)
# Output at last time point T
output_at_T = output[:, 27, :]
Full code: http://pastebin.com/bhf9MgMe
The input to the LSTM has shape (batch_size, sequence_length, input_size).
As a result, output has shape (batch_size, sequence_length, num_units), where num_units=200.
I need to get the last output along the sequence_length dimension. In the code above, this is hardcoded as 27. However, I do not know the sequence_length in advance, as it can change from batch to batch in my application.
I tried:
output_at_T = output[:, -1, :]
but it says negative indexing is not implemented yet. I also tried using a placeholder variable as well as a constant (into which I could ideally feed the sequence_length for a particular batch); neither worked.
Is there any way to implement something like this in TensorFlow at the moment?
Have you noticed that there are two outputs from dynamic_rnn?
So from:
h, final_state = tf.nn.dynamic_rnn(..., sequence_length=[batch_size_vector], ...)
the last state for each element in the batch is:
final_state.h
Note that this includes the case when the length of the sequence is different for each element of the batch, as we are using the sequence_length argument.
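To make the point above concrete, here is a minimal NumPy sketch of the behaviour dynamic_rnn is documented to have when sequence_length is passed: past a sequence's length the state is copied through and the outputs are zeroed. The function toy_dynamic_rnn and its plain tanh cell are hypothetical stand-ins for illustration only, not TensorFlow's implementation and not an LSTM.

```python
import numpy as np

def toy_dynamic_rnn(x, seq_len, w):
    """Toy stand-in for dynamic_rnn using h_t = tanh(x_t @ w + h_{t-1})."""
    batch, max_time, _ = x.shape
    units = w.shape[1]
    h = np.zeros((batch, units))
    outputs = np.zeros((batch, max_time, units))
    for t in range(max_time):
        new_h = np.tanh(x[:, t, :] @ w + h)
        active = (t < seq_len)[:, None]        # sequences that are still running
        h = np.where(active, new_h, h)         # finished sequences keep their state
        outputs[:, t, :] = np.where(active, new_h, 0.0)  # and emit zeros
    return outputs, h

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 4))                 # (batch_size, max_time, input_size)
seq_len = np.array([5, 2, 3])                  # true length of each sequence
w = rng.normal(size=(4, 6))                    # hypothetical cell weights

outputs, final_h = toy_dynamic_rnn(x, seq_len, w)

# The final state matches the output at the last valid step of each sequence.
last_valid = outputs[np.arange(3), seq_len - 1]
print(np.allclose(final_h, last_valid))
```

Under these assumptions, the final state already holds the last relevant output for each batch element, so no indexing into the output tensor is needed.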
This is what gather_nd is for!
def extract_axis_1(data, ind):
    """
    Get specified elements along axis 1 of a tensor.
    :param data: TensorFlow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    # Pair each batch index with the index to pick along axis 1.
    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)
    return res
In your case (assuming sequence_length is a 1-D tensor with the length of each axis 0 element):
output = extract_axis_1(output, sequence_length - 1)
Now output is a tensor of dimension [batch_size, num_cells].
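For anyone unsure what the stacked indices look like, here is a small NumPy analogue of the same idea. It is only a sketch: NumPy integer-array indexing stands in for tf.gather_nd, and the toy values of output and sequence_length are made up for illustration.

```python
import numpy as np

# Toy "RNN output" with shape (batch_size=2, max_time=4, num_units=3).
output = np.arange(2 * 4 * 3).reshape(2, 4, 3)
sequence_length = np.array([4, 2])             # true length of each sequence

# One (batch_index, time_index) pair per batch element, as in extract_axis_1.
batch_range = np.arange(output.shape[0])
indices = np.stack([batch_range, sequence_length - 1], axis=1)

# NumPy counterpart of gather_nd for these index pairs.
last = output[indices[:, 0], indices[:, 1]]

print(indices)   # [[0 3], [1 1]]
print(last)      # rows [9 10 11] and [15 16 17]
```

Each row of indices selects one time step from one batch element, so the result has shape (batch_size, num_units).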
output[:, -1, :]
works in TensorFlow 1.x now.
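One caveat worth checking with negative indexing: it selects the last time step of the padded tensor, which is the last real output only for sequences that fill the whole time axis. The NumPy sketch below uses made-up values (ones for real outputs, zeros for padding) to show the difference; it assumes padded steps are zeroed, as dynamic_rnn does when sequence_length is given.

```python
import numpy as np

# Toy padded output: batch of 2, max_time 4, 3 units; zeros mark padded steps.
output = np.ones((2, 4, 3))
sequence_length = np.array([4, 2])
output[1, 2:, :] = 0.0                         # second sequence ends after 2 steps

last_step = output[:, -1, :]                   # last step of the padded tensor
last_valid = output[np.arange(2), sequence_length - 1]  # last real step per sequence

print(last_step[1])    # zeros: this is padding, not a real output
print(last_valid[1])   # ones: the actual last output of the second sequence
```

If every sequence in a batch has the same length, the two results coincide and the negative index is enough.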