
Translating a TensorFlow LSTM into synapticjs

I'm trying to build an interface between a TensorFlow basic LSTM that has already been trained and a JavaScript version that can run in the browser. The problem is that all of the literature I've read models LSTMs as mini-networks (using only connections, nodes and gates), whereas TensorFlow seems to have a lot more going on.

The two questions that I have are:

  1. Can the TensorFlow model be easily translated into a more conventional neural network structure?

  2. Is there a practical way to map the trainable variables that TensorFlow gives you to this structure?

I can get the 'trainable variables' out of TensorFlow. The issue is that they appear to have only one bias value per LSTM node, whereas most of the models I've seen include several biases: for the memory cell, the inputs and the output.

A. Vickory asked Dec 17 '15 00:12


1 Answer

Internally, the LSTMCell class stores the LSTM weights as one big matrix instead of 8 smaller ones, for efficiency. It is quite easy to divide that matrix horizontally and vertically to get the more conventional representation. However, it might be easier and more efficient if your library performs a similar optimization.

Here is the relevant piece of code of the BasicLSTMCell:

concat = linear([inputs, h], 4 * self._num_units, True)

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(1, 4, concat)

The linear function performs a matrix multiplication that maps the concatenated input and previous h state to a single [batch_size, 4 * self._num_units] matrix. It uses one weight matrix and one bias vector, which are the variables you're referring to in the question. The result is then split into 4 matrices of [batch_size, self._num_units] shape, one for each gate used by the LSTM transformation.
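To make the fused computation concrete, here is a minimal sketch of one LSTM step for a single example, written with plain Python lists so it runs without TensorFlow. The function name `lstm_step` and its argument names are my own. The state update at the end (including the `forget_bias` added to the forget gate before the sigmoid) is my reading of BasicLSTMCell and is not part of the snippet quoted above, so check it against the source of the version you trained with.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, h, c, matrix, bias, forget_bias=1.0):
    """One step for a single example (no batch dimension).

    x: list of input_size floats; h, c: lists of num_units floats.
    matrix: (input_size + num_units) rows by 4 * num_units columns.
    bias: list of 4 * num_units floats.
    forget_bias: assumed extra constant on the forget gate (verify for your version).
    """
    num_units = len(h)
    xh = list(x) + list(h)  # same order as linear([inputs, h], ...)

    # One fused linear transformation that covers all four gates.
    concat = [
        sum(xh[r] * matrix[r][col] for r in range(len(xh))) + bias[col]
        for col in range(4 * num_units)
    ]

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = (concat[k * num_units:(k + 1) * num_units] for k in range(4))

    # Assumed state update, as I understand BasicLSTMCell.
    new_c = [
        c[n] * sigmoid(f[n] + forget_bias) + sigmoid(i[n]) * math.tanh(j[n])
        for n in range(num_units)
    ]
    new_h = [math.tanh(new_c[n]) * sigmoid(o[n]) for n in range(num_units)]
    return new_h, new_c
```

A JavaScript port can follow the same structure: one loop over the 4 * num_units columns, then four slices.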

If you'd like to explicitly get the transformations for each gate, you can split that matrix and bias into 4 blocks. It is also quite easy to implement it from scratch using 4 or 8 linear transformations.
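The splitting can be sketched as follows. This is an illustration under assumptions, not code from TensorFlow: `split_lstm_weights` is a hypothetical helper, the matrix is assumed to have `input_size + num_units` rows (input rows first, because the call is `linear([inputs, h], ...)`) and `4 * num_units` columns in the i, j, f, o order of the quoted code. Other TensorFlow versions or other libraries may order the gates differently.

```python
def split_lstm_weights(matrix, bias, input_size, num_units):
    """Split the fused weight matrix and bias into per-gate blocks.

    matrix: list of (input_size + num_units) rows, each 4 * num_units long.
    bias: list of 4 * num_units floats.
    Returns a dict keyed by gate name with W_x, W_h and b for each gate.
    """
    gate_names = ["i", "j", "f", "o"]  # order taken from the quoted split
    gates = {}
    for k, name in enumerate(gate_names):
        cols = slice(k * num_units, (k + 1) * num_units)
        block = [row[cols] for row in matrix]  # vertical cut: one gate
        gates[name] = {
            "W_x": block[:input_size],  # horizontal cut: input -> gate
            "W_h": block[input_size:],  # horizontal cut: previous h -> gate
            "b": bias[cols],            # one bias vector per gate
        }
    return gates
```

The four column cuts give the 4-block form, and the additional row cut gives the 8 smaller matrices. This also accounts for the single bias variable in the question: it holds four per-gate bias vectors laid end to end.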

Rafał Józefowicz answered Oct 16 '22 12:10
