
Tensorflow: Low level LSTM implementation

I am looking for a low-level implementation of an RNN with LSTM cells in Tensorflow. I have already implemented several feedforward networks using the low-level APIs, which helped me a lot to understand the inner workings of ANNs. Can I do the same for an RNN, or is it recommended to use the Tensorflow implementation of an LSTM cell (tf.nn.rnn_cell.BasicLSTMCell)? I haven't found any low-level implementation of an RNN in Tensorflow. Where could I find such a low-level implementation? Is Tensorflow designed for this at all? Where could I start? I hope that a few of my questions can be answered here.

asked Oct 27 '22 by Gilfoyle


1 Answer

1) Using tf.scan

A low-level implementation of an RNN can be achieved with the tf.scan function. For example, for a SimpleRNN, the implementation will look similar to this:

# our RNN variables
Wx = tf.get_variable(name='Wx', shape=[embedding_size, rnn_size])
Wh = tf.get_variable(name='Wh', shape=[rnn_size, rnn_size])
bias_rnn = tf.get_variable(name='brnn', initializer=tf.zeros([rnn_size]))


# a single RNN step
# simple RNN formula: h_t = tanh(x_t @ Wx + h_{t-1} @ Wh + b)
def rnn_step(prev_hidden_state, x):
    return tf.tanh(tf.matmul(x, Wx) + tf.matmul(prev_hidden_state, Wh) + bias_rnn)

# our unroll function
# note that the input must be transposed to time-major order,
# i.e. (time, batch, features), before being passed to tf.scan
hidden_states = tf.scan(fn=rnn_step,
                        elems=tf.transpose(embed, perm=[1, 0, 2]),
                        initializer=tf.zeros([batch_size, rnn_size]))

# convert back to the previous batch-major shape
outputs = tf.transpose(hidden_states, perm=[1, 0, 2])

# extract last hidden
last_rnn_output = outputs[:, -1, :]

See complete example here.
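To make the mechanics of this unroll concrete without a TensorFlow session, here is a NumPy sketch of the same scan: the identical `rnn_step` formula folded over the time axis, collecting every hidden state. All dimensions are toy values chosen for illustration, not from the linked example.

```python
import numpy as np

# toy dimensions (assumed for illustration)
batch_size, seq_len, embedding_size, rnn_size = 2, 5, 4, 3

rng = np.random.default_rng(0)
Wx = rng.normal(scale=0.1, size=(embedding_size, rnn_size))
Wh = rng.normal(scale=0.1, size=(rnn_size, rnn_size))
bias_rnn = np.zeros(rnn_size)

embed = rng.normal(size=(batch_size, seq_len, embedding_size))

def rnn_step(prev_hidden_state, x):
    # same formula as the TF version: tanh(x @ Wx + h @ Wh + b)
    return np.tanh(x @ Wx + prev_hidden_state @ Wh + bias_rnn)

# scan: fold rnn_step over the time axis, like tf.scan over time-major elems
h = np.zeros((batch_size, rnn_size))       # initializer
hidden_states = []
for t in range(seq_len):
    h = rnn_step(h, embed[:, t, :])
    hidden_states.append(h)

# stack back into batch-major shape (batch, time, units)
outputs = np.stack(hidden_states, axis=1)

# extract last hidden state
last_rnn_output = outputs[:, -1, :]
print(outputs.shape)  # (2, 5, 3)
```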

2) Using AutoGraph

tf.scan is essentially a for-loop, and you can implement the same loop with the AutoGraph API as well:

from tensorflow.python import autograph as ag

@ag.convert()
def f(x):
    # ...
    for ch in chars:
        cell_output, (state, output) = cell.call(ch, (state, output))
        hidden_outputs.append(cell_output)
    hidden_outputs = ag.stack(hidden_outputs)
    # ...

See complete example with autograph API here.

3) Implement in Numpy

If you still need to go deeper and implement an RNN yourself, see this tutorial, which implements an RNN with NumPy.
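Since the question specifically asks about LSTM cells, here is a hedged NumPy sketch of a single LSTM step using the standard gate equations (sigmoid input/forget/output gates, tanh candidate). The function signature, weight layout, and dimensions are my own choices for illustration, not taken from the tutorial.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step. W has shape (input_size + hidden_size, 4 * hidden_size),
    packing the weights of all four gates into one matrix."""
    z = np.concatenate([x, h_prev], axis=-1) @ W + b
    i, f, g, o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input/forget/output gates
    g = np.tanh(g)                                # candidate cell state
    c = f * c_prev + i * g                        # new cell state
    h = o * np.tanh(c)                            # new hidden state
    return h, c

# toy usage (assumed dimensions)
rng = np.random.default_rng(0)
input_size, hidden_size, batch = 4, 3, 2
W = rng.normal(scale=0.1, size=(input_size + hidden_size, 4 * hidden_size))
b = np.zeros(4 * hidden_size)
h = c = np.zeros((batch, hidden_size))
x = rng.normal(size=(batch, input_size))
h, c = lstm_step(x, h, c, W, b)
print(h.shape, c.shape)  # (2, 3) (2, 3)
```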

4) Custom RNN cell in Keras

See here.
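A custom Keras cell boils down to a class with a `state_size` attribute and a `call(inputs, states)` method returning `(output, new_states)`. To show that contract without requiring Keras, here is a framework-free NumPy sketch; the class name and weight initialization are my own, hypothetical choices.

```python
import numpy as np

class MinimalRNNCell:
    """Framework-free sketch of the Keras RNN-cell contract."""
    def __init__(self, units, input_size, seed=0):
        rng = np.random.default_rng(seed)
        self.state_size = units
        self.Wx = rng.normal(scale=0.1, size=(input_size, units))
        self.Wh = rng.normal(scale=0.1, size=(units, units))
        self.b = np.zeros(units)

    def call(self, inputs, states):
        prev_h = states[0]
        h = np.tanh(inputs @ self.Wx + prev_h @ self.Wh + self.b)
        return h, [h]  # (output, new_states), as the Keras contract expects

# unrolling the cell over a toy zero sequence
cell = MinimalRNNCell(units=3, input_size=4)
x = np.zeros((2, 5, 4))            # (batch, time, features)
states = [np.zeros((2, 3))]
for t in range(5):
    out, states = cell.call(x[:, t, :], states)
print(out.shape)  # (2, 3)
```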

answered Oct 31 '22 by Amir