 

How to restore an LSTM layer

Tags:

tensorflow

I would really appreciate it if I could get some help in saving and restoring LSTMs.

I have this LSTM layer:

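import tensorflow as tf  # TensorFlow 1.x API

# LSTM cell (n_hidden and word_vectors, shaped [batch, time, dim], are defined elsewhere)
cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

# Switch to time-major [time, batch, n_hidden] and take the last time step
outputs = tf.transpose(output, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)

# Saver function (sess is a tf.Session in which the variables were initialized and trained)
saver = tf.train.Saver()
saver.save(sess, 'test-model')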

The saver saves the model, so I can save and restore the weights and biases of the LSTM. However, I need to restore this LSTM layer and feed it a new set of inputs.

To restore the entire model, I'm doing:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

  1. Is it possible for me to initialize an LSTM cell with the pre-trained weights and biases?

  2. If not, how do I restore this LSTM layer?

Thank you very much!

asked Jul 17 '17 22:07 by AnnaR


1 Answer

You are already loading the model, and with it the trained weights and biases of the LSTM. All you need to do is call get_tensor_by_name on the restored graph to fetch the input and output tensors, and then feed new inputs to them for inference.
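get_tensor_by_name takes a tensor name of the form <op_name>:<output_index>, so 'outputs:0' means the first output of the op named outputs; passing a bare op name such as 'outputs' raises a ValueError. The two helpers below are hypothetical (not part of TensorFlow) and only spell out that naming convention in plain Python:

```python
def parse_tensor_name(name):
    """Split a TensorFlow-style tensor name such as 'outputs:0'
    into (op_name, output_index). Hypothetical helper."""
    op_name, sep, index = name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"expected '<op_name>:<output_index>', got {name!r}")
    return op_name, int(index)


def tensor_name(op_name, output_index=0):
    """Build the string to pass to get_tensor_by_name for one output
    of an op. Hypothetical helper."""
    return f"{op_name}:{output_index}"


print(parse_tensor_name("rnn/lstm_cell/kernel:0"))  # ('rnn/lstm_cell/kernel', 0)
print(tensor_name("word_vec"))                       # word_vec:0
```

Most ops have a single output, which is why the suffix is almost always :0.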

Example:

import tensorflow as tf  # TensorFlow 1.x API

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    graph = sess.graph  # the graph that import_meta_graph populated

    # Get the tensors by the names they were given when the graph was built
    word_vec = graph.get_tensor_by_name('word_vec:0')
    output_tensor = graph.get_tensor_by_name('outputs:0')

    sess.run(output_tensor, feed_dict={word_vec: ...})

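In the example above, word_vec and outputs are the names assigned to the tensors when the graph was built. Make sure you name these tensors explicitly (for example with the name argument of tf.placeholder, or by wrapping a result in tf.identity(..., name='outputs')) so that they can be retrieved by name after restoring.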
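On the first question (initializing an LSTM with the pre-trained weights): the restored kernel and bias are ordinary arrays, so they can also be reused outside the graph. Below is a minimal NumPy sketch, assuming LSTMCell's default configuration (tanh activation, forget_bias=1.0, no peepholes or projection), a single kernel of shape [input_dim + n_hidden, 4 * n_hidden] applied to the concatenation [input, h], and gates ordered i, j, f, o. The function name lstm_forward is hypothetical. The arrays could be read with tf.train.load_variable (available in later TF 1.x releases), assuming the default variable names rnn/lstm_cell/kernel and rnn/lstm_cell/bias; verify them with tf.train.list_variables.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(inputs, kernel, bias, forget_bias=1.0):
    """Run a single-layer LSTM over `inputs` shaped [batch, time, input_dim].

    Assumed layout (default LSTMCell): kernel [input_dim + n_hidden, 4 * n_hidden],
    bias [4 * n_hidden], gate blocks ordered (i, j, f, o).
    Returns (outputs [batch, time, n_hidden], (c, h) final state).
    """
    batch, steps, _ = inputs.shape
    n_hidden = bias.shape[0] // 4
    c = np.zeros((batch, n_hidden))
    h = np.zeros((batch, n_hidden))
    outputs = []
    for t in range(steps):
        # One matmul computes all four gate pre-activations at once
        z = np.concatenate([inputs[:, t, :], h], axis=1) @ kernel + bias
        i, j, f, o = np.split(z, 4, axis=1)
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1), (c, h)


# Usage with random stand-in weights (real ones would come from the checkpoint)
rng = np.random.default_rng(0)
input_dim, n_hidden = 3, 5
kernel = rng.normal(size=(input_dim + n_hidden, 4 * n_hidden))
bias = np.zeros(4 * n_hidden)
new_inputs = rng.normal(size=(2, 7, input_dim))
outputs, (c, h) = lstm_forward(new_inputs, kernel, bias)
last = outputs[:, -1, :]  # plays the role of `last` in the question
```

Inside TensorFlow, restoring the graph as shown in the answer is usually simpler; the sketch mainly illustrates that the saved kernel and bias fully determine the layer's behaviour on new inputs.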

answered Sep 27 '22 19:09 by vijay m