This is the code:
X = tf.placeholder(tf.float32, [batch_size, seq_len_1, 1], name='X')
labels = tf.placeholder(tf.float32, [None, alpha_size], name='labels')
rnn_cell = tf.contrib.rnn.BasicLSTMCell(512)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell([rnn_cell] * 3, state_is_tuple=True)
pre_prediction, state = tf.nn.dynamic_rnn(m_rnn_cell, X, dtype=tf.float32)
This is the full error:
ValueError: Trying to share variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (1024, 2048) and found shape (513, 2048).
I'm using a GPU version of tensorflow.
I encountered a similar problem when I upgraded to v1.2 (tensorflow-gpu).
Instead of using [rnn_cell]*3, I created 3 rnn_cells (stacked_rnn) in a loop (so that they don't share variables), fed MultiRNNCell with stacked_rnn, and the problem went away. I'm not sure it is the right way to do it.
stacked_rnn = []
for iiLyr in range(3):
    # Construct a new LSTMCell per layer so each layer gets its own variables
    stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=512, state_is_tuple=True))
MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
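Part of why the loop helps is plain Python list semantics: [rnn_cell] * 3 builds a list holding three references to one and the same cell object, while the loop constructs a new object on every iteration. A minimal pure-Python sketch, using a hypothetical Cell class as a stand-in for an RNN cell (it is not a TensorFlow class):

```python
class Cell:
    """Hypothetical stand-in for an RNN cell object; not a TensorFlow class."""

    def __init__(self, num_units):
        self.num_units = num_units


# List repetition copies the reference, not the object:
# all three entries point to the single Cell created here.
shared = [Cell(512)] * 3

# Calling the constructor inside a loop creates a new object per layer.
separate = []
for _ in range(3):
    separate.append(Cell(512))

print(len({id(c) for c in shared}))    # 1 distinct object
print(len({id(c) for c in separate}))  # 3 distinct objects
```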
The official TensorFlow tutorial recommends this way of defining a multi-layer LSTM network:
def lstm_cell():
    return tf.contrib.rnn.BasicLSTMCell(lstm_size)
stacked_lstm = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell() for _ in range(number_of_layers)])
You can find it here: https://www.tensorflow.org/tutorials/recurrent
Actually, it is almost the same approach that Wasi Ahmad and Maosi Chen suggested above, but in a slightly more elegant form.
I guess it's because your RNN cells on each of your 3 layers share the same input and output shape.
On layer 1, the input dimension is 513 = 1 (your x dimension) + 512 (dimension of the hidden layer) for each time step per batch.
On layers 2 and 3, the input dimension is 1024 = 512 (output from the previous layer) + 512 (output from the previous time step).
The way you stack up your MultiRNNCell probably implies that 3 cells share the same input and output shape.
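The shape arithmetic above can be checked with a small helper. This is a sketch assuming the LSTM kernel has (input_depth + num_units) rows and 4 * num_units columns (one block per gate), which is consistent with the (513, 2048) and (1024, 2048) shapes in the error message; lstm_kernel_shape and stacked_kernel_shapes are hypothetical helpers, not TensorFlow functions:

```python
def lstm_kernel_shape(input_depth, num_units):
    # The kernel multiplies the concatenation [x_t, h_{t-1}], so it has
    # input_depth + num_units rows; the 4 gates each take num_units columns.
    return (input_depth + num_units, 4 * num_units)


def stacked_kernel_shapes(input_depth, num_units, num_layers):
    shapes = []
    depth = input_depth
    for _ in range(num_layers):
        shapes.append(lstm_kernel_shape(depth, num_units))
        depth = num_units  # each layer feeds its num_units-wide output upward
    return shapes


print(stacked_kernel_shapes(1, 512, 3))
# [(513, 2048), (1024, 2048), (1024, 2048)]
```

Layer 1 needs a 513-row kernel while layers 2 and 3 need 1024-row kernels, which is exactly the mismatch reported when one cell object is reused for all three layers.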
I stack up MultiRNNCell by declaring two separate cell objects so that they don't share an input shape:
rnn_cell1 = tf.contrib.rnn.BasicLSTMCell(512)  # layer 1: input depth 1 + 512 = 513
rnn_cell2 = tf.contrib.rnn.BasicLSTMCell(512)  # layers 2 and 3: input depth 512 + 512 = 1024
stack_rnn = [rnn_cell1]
for i in range(1, 3):
    stack_rnn.append(rnn_cell2)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell(stack_rnn, state_is_tuple=True)
Then I am able to train my data without this bug. I'm not sure whether my guess is correct, but it works for me. Hope it works for you.
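To illustrate the mechanism discussed in this thread, here is a toy pure-Python model, assuming a cell creates its kernel the first time it is called and reuses that kernel on every later call. ToyLSTMCell, run_stack and try_stack are hypothetical names, and this is not how TensorFlow is implemented; it only mimics the shape check from the error message. It compares [cell] * 3, three separately constructed cells, and the two-cell layout above:

```python
class ToyLSTMCell:
    """Toy model of a cell that creates its kernel on the first call and
    reuses it afterwards. Hypothetical; this is not TensorFlow's code."""

    def __init__(self, num_units):
        self.num_units = num_units
        self.kernel_shape = None  # created lazily on the first call

    def __call__(self, input_depth):
        shape = (input_depth + self.num_units, 4 * self.num_units)
        if self.kernel_shape is None:
            self.kernel_shape = shape  # first call: "create" the kernel
        elif self.kernel_shape != shape:
            # Later call with a different input depth: shapes conflict
            raise ValueError("specified shape %s and found shape %s"
                             % (shape, self.kernel_shape))
        return self.num_units  # output depth fed to the next layer


def run_stack(cells, input_depth=1):
    depth = input_depth
    for cell in cells:
        depth = cell(depth)
    return depth


def try_stack(cells):
    try:
        run_stack(cells)
        return "ok"
    except ValueError as err:
        return "ValueError: %s" % err


cell = ToyLSTMCell(512)
print(try_stack([cell] * 3))
# ValueError: specified shape (1024, 2048) and found shape (513, 2048)
print(try_stack([ToyLSTMCell(512) for _ in range(3)]))  # ok

rnn_cell1 = ToyLSTMCell(512)
rnn_cell2 = ToyLSTMCell(512)
print(try_stack([rnn_cell1, rnn_cell2, rnn_cell2]))  # ok
```

In this toy model the two-cell layout passes because layers 2 and 3 both request a (1024, 2048) kernel from the same rnn_cell2 object, so those two layers also end up using one shared kernel.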