
TensorFlow: How to get all variables from rnn_cell.BasicLSTMCell and rnn_cell.MultiRNNCell

I have a setup where I need to initialize an LSTM after the main initialization, which uses tf.initialize_all_variables(). That is, I want to call tf.initialize_variables(var_list) afterwards.

Is there a way to collect all the internal trainable variables for both:

  • rnn_cell.BasicLSTMCell
  • rnn_cell.MultiRNNCell

so that I can initialize JUST these parameters?

The main reason is that I do not want to re-initialize values that were trained earlier.

asked Jan 26 '16 11:01 by bge0


2 Answers

The easiest way to solve this is to use a variable scope: the names of all variables created within a scope are prefixed with the scope's name. Here is a short snippet:

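cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    if i > 0:
      # Share the cell's weights across time steps instead of creating new ones.
      tf.get_variable_scope().reuse_variables()
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables. The trailing "/" keeps the match
  # from also picking up other scopes whose names begin with "LSTM".
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name + "/")]

# [..]
# Initialize the LSTM variables. tf.initialize_variables() only builds an op,
# so it has to be run in a session (sess here is assumed to be your tf.Session).
sess.run(tf.initialize_variables(lstm_variables))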

It would work the same way with MultiRNNCell.
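
One caveat with name matching: a test on the bare scope name is a plain prefix test, so it would also pick up variables from any other scope whose name happens to begin with the same characters. Appending "/" to the prefix avoids that. Here is a small sketch in plain Python (no TensorFlow needed); the variable names and the names_in_scope helper are made up for illustration and only stand in for v.name:

```python
# Made-up names standing in for v.name; real names depend on the TensorFlow
# version and on how the graph was built.
variable_names = [
    "embedding:0",
    "LSTM/BasicLSTMCell/Linear/Matrix:0",
    "LSTM/BasicLSTMCell/Linear/Bias:0",
    "LSTM_decoder/BasicLSTMCell/Linear/Matrix:0",
    "softmax/weights:0",
]


def names_in_scope(names, scope_name, strict=True):
    """Return the names that live under scope_name.

    With strict=True the prefix is scope_name + "/", so only that scope
    (and scopes nested inside it) match. With strict=False the bare scope
    name is used, as in a plain startswith(vs.name) test.
    """
    prefix = scope_name + "/" if strict else scope_name
    return [name for name in names if name.startswith(prefix)]


loose = names_in_scope(variable_names, "LSTM", strict=False)
strict = names_in_scope(variable_names, "LSTM", strict=True)

print(len(loose))   # 3: also includes the LSTM_decoder variable
print(len(strict))  # 2: only the variables under LSTM/
```

If no other scope shares the prefix, both forms select the same variables, so the stricter form costs nothing.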

EDIT: changed tf.trainable_variables() to tf.all_variables().

answered Nov 15 '22 19:11 by Rafał Józefowicz


You can also use tf.get_collection():

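cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    if i > 0:
      # Share the cell's weights across time steps instead of creating new ones.
      tf.get_variable_scope().reuse_variables()
    output[i], state = cell(input_data[i], state)

  # Fetch only the variables whose names fall under the "LSTM" scope.
  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)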

(partly copied from Rafal's answer)

Note that the last line is equivalent to the list comprehension in Rafal's code.

Basically, TensorFlow stores a global collection of variables, which can be fetched with either tf.all_variables() or tf.get_collection(tf.GraphKeys.VARIABLES). (Later releases rename these to tf.global_variables() and tf.GraphKeys.GLOBAL_VARIABLES.) If you pass a scope name as the scope argument of tf.get_collection(), you fetch only the items in the collection (variables in this case) whose names fall under that scope.

EDIT: You can also use tf.GraphKeys.TRAINABLE_VARIABLES to get trainable variables only. Since a vanilla BasicLSTMCell does not create any non-trainable variables, the two are functionally equivalent here. For a complete list of the default graph collections, see the tf.GraphKeys documentation.
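
To make the collection idea concrete, here is a toy model in plain Python. It is only a sketch of the concept, not TensorFlow's implementation: Var, ToyGraph and the variable names are hypothetical. The scope filter uses re.match, which is how the TensorFlow releases I am aware of document the scope argument; the version this answer was written against may have used a plain prefix test instead.

```python
import re
from collections import defaultdict


class Var:
    """Hypothetical stand-in for a TensorFlow variable: just a name and a flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


class ToyGraph:
    """Toy model of graph collections: a mapping from key to a list of items."""

    VARIABLES = "variables"
    TRAINABLE_VARIABLES = "trainable_variables"

    def __init__(self):
        self._collections = defaultdict(list)

    def add_variable(self, var):
        # Every variable goes into the global collection ...
        self._collections[self.VARIABLES].append(var)
        # ... and trainable ones additionally go into the trainable collection.
        if var.trainable:
            self._collections[self.TRAINABLE_VARIABLES].append(var)

    def get_collection(self, key, scope=None):
        items = list(self._collections.get(key, []))
        if scope is None:
            return items
        # re.match anchors at the start of the name, so a plain scope string
        # behaves like a prefix filter.
        pattern = re.compile(scope)
        return [item for item in items if pattern.match(item.name)]


graph = ToyGraph()
graph.add_variable(Var("global_step:0", trainable=False))
graph.add_variable(Var("embedding:0"))
graph.add_variable(Var("LSTM/BasicLSTMCell/Linear/Matrix:0"))
graph.add_variable(Var("LSTM/BasicLSTMCell/Linear/Bias:0"))

all_lstm = graph.get_collection(ToyGraph.VARIABLES, scope="LSTM/")
trainable_lstm = graph.get_collection(ToyGraph.TRAINABLE_VARIABLES, scope="LSTM/")

# Both queries return the same two variables, because the toy LSTM scope
# contains no non-trainable variables.
print([v.name for v in all_lstm] == [v.name for v in trainable_lstm])  # True
print(len(graph.get_collection(ToyGraph.VARIABLES)))  # 4
```

The two collections differ only outside the LSTM scope here: the non-trainable global_step shows up in the global collection but not in the trainable one.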

answered Nov 15 '22 20:11 by Minjoon Seo