Tensorflow variable reuse

I've built my LSTM model. Ideally, I want to reuse its variables to define a test LSTM model later.

with tf.variable_scope('lstm_model') as scope:
    # Define LSTM Model
    lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                            training_seq_len, vocab_size)
    scope.reuse_variables()
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer=True)

The code above gives me this error:

Variable lstm_model/lstm_vars/W already exists, disallowed. Did you mean to set reuse=True in VarScope? 

If I set reuse=True, as shown in the code block below,

with tf.variable_scope('lstm_model', reuse=True) as scope:

I get a different error:

Variable lstm_model/lstm_model/lstm_vars/W/Adam/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

For reference, here is the relevant model code. This is the section of the LSTM model where I define the weights:

with tf.variable_scope('lstm_vars'):
    # Softmax Output Weights
    W = tf.get_variable('W', [self.rnn_size, self.vocab_size], tf.float32, tf.random_normal_initializer())

The corresponding section where I have my Adam optimizer:

optimizer = tf.train.AdamOptimizer(self.learning_rate)
asked Dec 03 '22 21:12 by caramelslice


2 Answers

Instead of building both models inside the same variable scope:

with tf.variable_scope('lstm_model') as scope:
    # Define LSTM Model
    lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                            training_seq_len, vocab_size)
    scope.reuse_variables()
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer_sample=True)

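This version fixes the issue. It builds the training model, including its AdamOptimizer, outside any reuse scope, and builds only the test model inside a reuse=True scope. With reuse=True, tf.get_variable() can only look up variables that already exist, so the optimizer could not create its slot variables (such as W/Adam) there: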

# Define LSTM Model
lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                        training_seq_len, vocab_size)

# Reuse the training model's existing variables for the test model
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_lstm_model = LSTM_Model(rnn_size, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer_sample=True)
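To see why the order matters, here is a toy, pure-Python model of the reuse rule. It is a sketch under stated assumptions, not TensorFlow code: ToyVarStore and build_toy_model are hypothetical stand-ins in which reuse=False may only create new names, reuse=True may only look up existing ones, and a training build also creates one optimizer slot variable, standing in for what AdamOptimizer does for W.

```python
class ToyVarStore:
    """Hypothetical stand-in for TF1's variable store; models only reuse=False/True."""

    def __init__(self):
        self.vars = {}  # full variable name -> variable object

    def get(self, name, reuse):
        if reuse and name not in self.vars:
            raise ValueError(f"Variable {name} does not exist")
        if not reuse and name in self.vars:
            raise ValueError(f"Variable {name} already exists, disallowed")
        # Create on first use (reuse=False) or return the existing one (reuse=True)
        return self.vars.setdefault(name, object())


def build_toy_model(store, reuse, infer_sample=False):
    """Hypothetical LSTM_Model stand-in: one weight, plus an optimizer slot when training."""
    w = store.get("lstm_vars/W", reuse)
    if not infer_sample:
        # Stands in for AdamOptimizer creating a slot variable for W
        store.get("lstm_vars/W/Adam", reuse)
    return w


# The working order: training model first, then the test model with reuse=True
store = ToyVarStore()
train_w = build_toy_model(store, reuse=False)                   # creates W and W/Adam
test_w = build_toy_model(store, reuse=True, infer_sample=True)  # only looks up W
assert train_w is test_w

# The failing direction: a training build under reuse=True needs a brand-new slot
other = ToyVarStore()
build_toy_model(other, reuse=False, infer_sample=True)  # creates only W
slot_error = None
try:
    build_toy_model(other, reuse=True)  # W is found, but W/Adam was never created
except ValueError as err:
    slot_error = str(err)
print(slot_error)  # Variable lstm_vars/W/Adam does not exist
```

In this toy, only the inference model is built under reuse=True, so every name it asks for already exists. The second store shows the opposite case, where anything that has to create a new variable under reuse=True fails.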
answered Mar 19 '23 19:03 by caramelslice


If you use the same variable two or more times, open the scope with with tf.variable_scope('scope_name', reuse=False): the first time, and with with tf.variable_scope('scope_name', reuse=True): every time after that.

Alternatively, you can call the reuse_variables() method on the scope object returned by tf.variable_scope():

with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.get_variable("v", [1])

In the code above, v and v1 are the same variable.
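Both approaches can be illustrated with a toy model in plain Python. ToyVariableScope is hypothetical and only mimics the reuse=False / reuse=True behaviour described above; it is not the tf.variable_scope API, and it ignores reuse=None (inherit from the enclosing scope) and tf.AUTO_REUSE.

```python
class ToyVariableScope:
    """Hypothetical, simplified model of a TF1 variable scope (not the real tf API)."""

    def __init__(self, name, store, reuse=False):
        self.name = name
        self.store = store  # dict shared by all scopes: full name -> variable
        self.reuse = reuse

    def reuse_variables(self):
        # Mirrors scope.reuse_variables(): later lookups must find existing variables
        self.reuse = True

    def get_variable(self, name):
        full_name = f"{self.name}/{name}"
        if self.reuse:
            if full_name not in self.store:
                raise ValueError(f"Variable {full_name} does not exist")
        elif full_name in self.store:
            raise ValueError(f"Variable {full_name} already exists, disallowed")
        else:
            self.store[full_name] = object()
        return self.store[full_name]


store = {}

# reuse_variables() form: create once, then switch the same scope to reuse
scope = ToyVariableScope("foo", store)
v = scope.get_variable("v")
scope.reuse_variables()
v1 = scope.get_variable("v")
assert v is v1

# Two-scope form: reuse=False the first time, reuse=True afterwards
w = ToyVariableScope("bar", store, reuse=False).get_variable("w")
w1 = ToyVariableScope("bar", store, reuse=True).get_variable("w")
assert w is w1
```

Opening a fresh reuse=False scope and asking for "v" again would raise the "already exists, disallowed" error, which is the same rule behind the first error in the question.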

answered Mar 19 '23 21:03 by Vladimir Bystricky