Single-step simulation in TensorFlow RNN

I am working on an RNN controller that takes the current state of the plant as its input and generates the control signal as its output. After the control is applied, the updated plant state is fed back to the RNN as the input for the next time step. In this loop, the input sequence is built up step by step rather than given all in advance. For now, no training is involved; only single-step forward simulation is needed. So what I'm looking for is a TensorFlow RNN operation that computes the RNN output one step at a time.
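The feedback loop described above can be sketched independently of TensorFlow. This is a minimal NumPy sketch, not the asker's code and not TensorFlow's implementation: lstm_step is a hand-written version of the standard LSTM update (with a forget bias of 1.0), and the linear plant, its matrix B, the 0.9 decay and all sizes are hypothetical stand-ins. What it illustrates is that the recurrent state is carried from one call to the next, so each iteration only needs the current plant state as input.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, state, W, b):
    # One LSTM step: x is [batch, input_size], state is (c, h), each [batch, num_hidden]
    c, h = state
    i, j, f, o = np.split(np.concatenate([x, h], axis=1) @ W + b, 4, axis=1)
    c = c * sigmoid(f + 1.0) + sigmoid(i) * np.tanh(j)  # forget bias of 1.0
    h = np.tanh(c) * sigmoid(o)
    return h, (c, h)

def simulate(num_steps, state_dim=8, num_hidden=24, dim_action=2, seed=0):
    rng = np.random.default_rng(seed)
    # Untrained, randomly initialised controller weights (no training involved)
    W = rng.normal(scale=0.1, size=(state_dim + num_hidden, 4 * num_hidden))
    b = np.zeros(4 * num_hidden)
    W_out = rng.normal(scale=0.1, size=(num_hidden, dim_action))
    b_out = np.full(dim_action, 0.1)
    # Hypothetical linear plant: s_next = 0.9 * s + u @ B
    B = rng.normal(scale=0.1, size=(dim_action, state_dim))

    s = np.ones((1, state_dim))                                          # plant state, batch of 1
    rnn_state = (np.zeros((1, num_hidden)), np.zeros((1, num_hidden)))  # zero initial state
    actions = []
    for _ in range(num_steps):
        out, rnn_state = lstm_step(s, rnn_state, W, b)  # one RNN step on the current plant state
        u = np.tanh(out @ W_out + b_out)                # control signal (like y_single)
        s = 0.9 * s + u @ B                             # apply control; new state is the next input
        actions.append(u)
    return np.concatenate(actions, axis=0), s

actions, final_state = simulate(10)
print(actions.shape, final_state.shape)  # (10, 2) (1, 8)
```

Because each step depends only on earlier steps, extending the simulation does not change the actions already computed, which is why a single-step cell call is enough here and no full sequence has to exist in advance.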

I defined two kinds of input: input_data for a batch of batch_size input sequences, and input_single for the input at the current time step.

import tensorflow as tf

input_data = tf.placeholder(tf.float32, [batch_size, len_seq, 8])
input_single = tf.placeholder(tf.float32, [1, 1, 8])
action_gradient = tf.placeholder(tf.float32, [batch_size, len_seq, dimAction])
num_hidden = 24
cell = tf.nn.rnn_cell.LSTMCell(num_hidden, state_is_tuple=True)
# All-zero initial LSTM state (an LSTMStateTuple of c and h)
state_single = cell.zero_state(batch_size, tf.float32)
# Advance the cell by one time step
(output_single, state_single) = cell(input_single, state_single)
# Linear read-out from the hidden output to the control action, squashed to [-1, 1] by tanh
weight = tf.Variable(tf.truncated_normal([num_hidden, dimAction]))
bias = tf.Variable(tf.constant(0.1, shape=[dimAction]))
y_single = tf.nn.tanh(tf.matmul(output_single, weight) + bias)

The network is read out in two ways: y_single for a single time step, and y_seq for a whole minibatch of input sequences.

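outputs, states = tf.nn.dynamic_rnn(cell, input_data, dtype=tf.float32)
# outputs has shape [batch_size, len_seq, num_hidden]; apply the same read-out at every time step
y_seq = tf.nn.tanh(tf.tensordot(outputs, weight, axes=[[2], [0]]) + bias)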
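As a sanity check on the two read-outs, here is a small NumPy sketch (again a hand-written stand-in for the LSTM, with illustrative sizes and random weights, not TensorFlow code) showing that stepping one sequence through the cell one time step at a time, carrying the state over, gives the same outputs as unrolling the whole batch at once, provided both paths share the same cell and read-out weights. It also applies the read-out to a 3-D [batch, len_seq, num_hidden] array; NumPy's @ broadcasts over the leading axis.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, state, W, b):
    # One LSTM step: x is [batch, input_size], state is (c, h), each [batch, num_hidden]
    c, h = state
    i, j, f, o = np.split(np.concatenate([x, h], axis=1) @ W + b, 4, axis=1)
    c = c * sigmoid(f + 1.0) + sigmoid(i) * np.tanh(j)
    h = np.tanh(c) * sigmoid(o)
    return h, (c, h)

def readout(h, W_out, b_out):
    # Works on [batch, num_hidden] and [batch, len_seq, num_hidden] alike
    return np.tanh(h @ W_out + b_out)

batch_size, len_seq, input_size, num_hidden, dim_action = 3, 5, 8, 24, 2  # illustrative sizes
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(input_size + num_hidden, 4 * num_hidden))
b = np.zeros(4 * num_hidden)
W_out = rng.normal(scale=0.1, size=(num_hidden, dim_action))
b_out = np.full(dim_action, 0.1)
xs = rng.normal(size=(batch_size, len_seq, input_size))

# Whole-sequence path (what dynamic_rnn computes): unroll over time for the whole batch
state = (np.zeros((batch_size, num_hidden)), np.zeros((batch_size, num_hidden)))
outputs = []
for t in range(len_seq):
    h, state = lstm_step(xs[:, t, :], state, W, b)
    outputs.append(h)
y_seq = readout(np.stack(outputs, axis=1), W_out, b_out)  # [batch, len_seq, dim_action]

# Single-step path: feed only sequence 0, one time step per call, carrying the state
state = (np.zeros((1, num_hidden)), np.zeros((1, num_hidden)))
y_steps = []
for t in range(len_seq):
    h, state = lstm_step(xs[0:1, t, :], state, W, b)
    y_steps.append(readout(h, W_out, b_out)[0])
y_steps = np.stack(y_steps)  # [len_seq, dim_action]
```

In the TensorFlow graph the same equivalence holds only if both paths really use the same LSTM variables as well as the same weight and bias; whether a directly called cell and dynamic_rnn share the cell's kernel may depend on the TensorFlow version, so listing tf.trainable_variables() is a quick way to check.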
asked Dec 13 '16 at 14:12 by H. Shi


1 Answer

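You can achieve this by simply calling your tf.nn.rnn_cell.LSTMCell object once. Make sure you pass arguments of the right shapes: for a single call, the input is 2-D, [batch_size, input_size], with no time axis, and the state must have the same batch size. Something like this should work: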

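import tensorflow as tf
cell = tf.nn.rnn_cell.LSTMCell(num_hidden, state_is_tuple=True)
# Single-step input: 2-D [batch_size, input_size], no time axis
input_single = tf.ones([batch_size, input_size])
# All-zero initial state whose batch size matches the input
state_single = cell.zero_state(batch_size, tf.float32)
# One call advances the cell by exactly one time step
(output_single, state_single) = cell(input_single, state_single)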

If you have a good reason not to use cell.zero_state(), see the documentation for RNNCell.__call__() for the shapes that input_single and state_single must have.
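Applied to the question's code: input_single is declared as [1, 1, 8], which still has a time axis, and its state comes from cell.zero_state(batch_size, ...), which uses batch_size rather than 1. The NumPy sketch below mirrors those two shape rules with a hypothetical check_single_step helper (it is not part of TensorFlow). In TensorFlow the corresponding fix would be to declare the placeholder as [1, 8], or to drop the time axis with something like tf.squeeze(input_single, axis=1), and to call zero_state with the same batch size as the input.

```python
import numpy as np

def zero_state(batch_size, num_hidden):
    # (c, h) pair of zeros, the shapes an LSTM state tuple has
    return (np.zeros((batch_size, num_hidden)), np.zeros((batch_size, num_hidden)))

def check_single_step(x, state):
    """Hypothetical helper: raise if x/state do not fit a one-step cell call."""
    c, h = state
    if x.ndim != 2:
        raise ValueError(f"input must be 2-D [batch, input_size], got {x.shape}")
    if c.shape != h.shape or c.shape[0] != x.shape[0]:
        raise ValueError(f"state batch {c.shape[0]} does not match input batch {x.shape[0]}")

x_3d = np.ones((1, 1, 8))  # shaped like input_single in the question
x_2d = x_3d[:, 0, :]       # time axis dropped -> shape (1, 8)

check_single_step(x_2d, zero_state(1, 24))  # OK: 2-D input, matching batch of 1
for bad_x, bad_state in [(x_3d, zero_state(1, 24)), (x_2d, zero_state(4, 24))]:
    try:
        check_single_step(bad_x, bad_state)
    except ValueError as err:
        print("rejected:", err)
```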

answered Nov 19 '22 at 21:11 by martianwars