ValueError: Attempt to reuse RNNCell with a different variable scope than its first use

The following code fragment

import tensorflow as tf
from tensorflow.contrib import rnn

hidden_size = 100
batch_size  = 100
num_steps   = 100
num_layers  = 100
is_training = True
keep_prob   = 0.4

input_data = tf.placeholder(tf.float32, [batch_size, num_steps])
lstm_cell = rnn.BasicLSTMCell(hidden_size, forget_bias=0.0, state_is_tuple=True)

if is_training and keep_prob < 1:
    lstm_cell = rnn.DropoutWrapper(lstm_cell)
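# This comprehension repeats the single lstm_cell object num_layers times; it does not create new cells.
cell = rnn.MultiRNNCell([lstm_cell for _ in range(num_layers)], state_is_tuple=True)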

_initial_state = cell.zero_state(batch_size, tf.float32)

iw = tf.get_variable("input_w", [1, hidden_size])
ib = tf.get_variable("input_b", [hidden_size])
inputs = [tf.nn.xw_plus_b(i_, iw, ib) for i_ in tf.split(input_data, num_steps, 1)]

if is_training and keep_prob < 1:
    inputs = [tf.nn.dropout(input_, keep_prob) for input_ in inputs]

outputs, states = rnn.static_rnn(cell, inputs, initial_state=_initial_state)

produces the following error:

ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.BasicLSTMCell object at 0x10210d5c0> with a different variable scope than its first use. First use of cell was with scope 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell', this attempt is with scope 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell'.

Please create a new instance of the cell if you would like it to use a different set of weights.

If before you were using: MultiRNNCell([BasicLSTMCell(...)] * num_layers), change to: MultiRNNCell([BasicLSTMCell(...) for _ in range(num_layers)]).

If before you were using the same cell instance as both the forward and reverse cell of a bidirectional RNN, simply create two instances (one for forward, one for reverse).

In May 2017, we will start transitioning this cell's behavior to use existing stored weights, if any, when it is called with scope=None (which can lead to silent model degradation, so this error will remain until then.)
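The hint in the error message is easy to misread: a list comprehension only produces separate cells if the constructor is called inside it. In the code above, `[lstm_cell for _ in range(num_layers)]` re-reads the name `lstm_cell` on every iteration, so every layer gets the one cell built earlier, exactly as with `[BasicLSTMCell(...)] * num_layers`. The plain-Python sketch below uses a hypothetical `Cell` class and a hypothetical `make_cell` factory (neither is a TensorFlow API) to show the difference by counting distinct object identities.

```python
# Cell is a hypothetical stand-in for an RNN cell class, not a TensorFlow API.
class Cell:
    pass


num_layers = 3

# Pattern from the question: one object is built first, and the comprehension
# only repeats a reference to it.
lstm_cell = Cell()
shared = [lstm_cell for _ in range(num_layers)]

# Pattern the error message warns about: list repetition also shares one object.
repeated = [Cell()] * num_layers


# Hypothetical factory: the constructor runs once per call.
def make_cell():
    return Cell()


# Calling the factory inside the comprehension yields a fresh object per slot.
distinct = [make_cell() for _ in range(num_layers)]

print(len({id(c) for c in shared}))    # 1
print(len({id(c) for c in repeated}))  # 1
print(len({id(c) for c in distinct}))  # 3
```

Only the last pattern gives each layer its own object, which is what "create a new instance of the cell" in the error message asks for.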

How can I solve this problem?

My TensorFlow version is 1.0.

Asked Mar 08 '17 11:03 by Douglas Huang


1 Answer

As suggested in the comments, my solution was to change this

cell = tf.contrib.rnn.LSTMCell(state_size, state_is_tuple=True)
cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.8)
rnn_cells = tf.contrib.rnn.MultiRNNCell([cell for _ in range(num_layers)], state_is_tuple=True)
outputs, current_state = tf.nn.dynamic_rnn(rnn_cells, x, initial_state=rnn_tuple_state, scope="layer")

into:

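import tensorflow as tf

# state_size, num_layers, x and rnn_tuple_state come from the surrounding model code (not shown).
def lstm_cell():
    # Build a new cell on every call so that each layer gets its own weights.
    cell = tf.contrib.rnn.LSTMCell(state_size, reuse=tf.get_variable_scope().reuse)
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.8)

rnn_cells = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
outputs, current_state = tf.nn.dynamic_rnn(rnn_cells, x, initial_state=rnn_tuple_state)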

This seems to solve the reuse problem. I don't fully understand the underlying cause, but it fixed the issue for me on TF 1.1rc2.
Cheers!
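The fix works because `lstm_cell()` is now called once per layer, so `MultiRNNCell` receives a separate cell object (and therefore a separate set of weights) for each layer. The toy model below is a hypothetical sketch, not TensorFlow's implementation: `ToyCell` and `ToyMultiCell` are invented names that only mimic the rule stated in the error message, namely that a cell remembers the variable scope of its first use and refuses to run under a different one.

```python
# Hypothetical toy classes that mimic the rule in the error message.
# They are NOT TensorFlow code.
class ToyCell:
    def __init__(self):
        self.first_scope = None  # scope recorded on first use

    def __call__(self, scope):
        if self.first_scope is None:
            self.first_scope = scope
        elif scope != self.first_scope:
            raise ValueError(
                f"Attempt to reuse cell with scope {scope!r}; "
                f"first use was with scope {self.first_scope!r}"
            )
        return f"output of {scope}"


class ToyMultiCell:
    def __init__(self, cells):
        self.cells = cells

    def __call__(self):
        # Each layer runs under its own scope name: cell_0, cell_1, ...
        return [cell(f"multi_rnn_cell/cell_{i}") for i, cell in enumerate(self.cells)]


def lstm_cell():
    # Factory in the style of the answer: a brand-new cell on every call.
    return ToyCell()


num_layers = 3

# One instance shared by every layer: the call for cell_1 raises ValueError.
shared_cell = ToyCell()
try:
    ToyMultiCell([shared_cell for _ in range(num_layers)])()
    shared_ok = True
except ValueError as err:
    shared_ok = False
    print(err)

# One instance per layer: every cell only ever sees a single scope.
outputs = ToyMultiCell([lstm_cell() for _ in range(num_layers)])()
print(outputs)
```

In the shared case the first layer succeeds and the second fails, which matches the `cell_0` / `cell_1` pair reported in the question's traceback.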

Answered Oct 22 '22 14:10 by dv3