 

TensorFlow - LSTM - 'Tensor' object is not iterable

Hi, I am using the following function to build an LSTM RNN cell.

import tensorflow as tf
# Assumed TensorFlow r0.10/r0.11-era imports; the question does not show them.
from tensorflow.python.ops import rnn, rnn_cell

def LSTM_RNN(_X, _istate, _weights, _biases):
    # Function returns a TensorFlow LSTM (RNN) artificial neural network from the given parameters.
    # n_input, n_hidden and n_steps are module-level globals defined elsewhere.
    # Note: some code of this notebook is inspired by a slightly different
    # RNN architecture used on another dataset:
    # https://tensorhub.com/aymericdamien/tensorflow-rnn

    # (NOTE: This step could be greatly optimised by shaping the dataset once.)
    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size

    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)

    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define a lstm cell with tensorflow
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)


    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']
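The shape bookkeeping that LSTM_RNN does before calling rnn.rnn can be traced with a small NumPy sketch. The sizes below are made-up values (the question defines n_input, n_hidden and n_steps elsewhere) and the weights are random placeholders. np.split(H, n_steps, axis=0) plays the role of the question's tf.split(0, n_steps, _X), which uses the pre-1.0 argument order (split dimension first).

```python
import numpy as np

# Hypothetical sizes; the question defines these globals elsewhere.
batch_size, n_steps, n_input, n_hidden = 4, 128, 9, 32

rng = np.random.default_rng(0)
X = rng.standard_normal((batch_size, n_steps, n_input))
W_hidden = rng.standard_normal((n_input, n_hidden))  # placeholder weights
b_hidden = np.zeros(n_hidden)                         # placeholder biases

# Like tf.transpose(_X, [1, 0, 2]): (n_steps, batch_size, n_input)
X_t = np.transpose(X, (1, 0, 2))
# Like tf.reshape(_X, [-1, n_input]): (n_steps*batch_size, n_input)
X_flat = X_t.reshape(-1, n_input)
# Linear activation: (n_steps*batch_size, n_hidden)
H = X_flat @ W_hidden + b_hidden
# Like the old tf.split(0, n_steps, _X): n_steps arrays of (batch_size, n_hidden)
steps = np.split(H, n_steps, axis=0)

print(len(steps), steps[0].shape)  # 128 (4, 32)
```

Because the batch and time axes are swapped before the reshape, each element of steps holds one time step for the whole batch, which is the list-of-inputs format the old rnn.rnn expects.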

The function's output is stored in the pred variable:

pred = LSTM_RNN(x, istate, weights, biases)

But it raises the following error, which states that a 'Tensor' object is not iterable.

Here is a screenshot of the error: http://imgur.com/a/NhSFK

Please help me with this. I apologize if this question seems silly; I am fairly new to LSTMs and the TensorFlow library.

Thanks.

Daniel Fox asked Nov 07 '16 11:11



2 Answers

The error happens when the cell tries to unpack its state with the statement c, h = state. Whether that works depends on which version of TensorFlow you are using (you can check the version by typing import tensorflow; tensorflow.__version__ in the Python interpreter). In versions prior to r0.11, the state_is_tuple argument of rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0) defaults to False, so the cell expects its state as a single tensor. See the documentation here:

BasicLSTMCell documentation in r0.10
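To see why the layout of the state matters, here is a minimal NumPy sketch of a single LSTM step. It is a hypothetical stand-in loosely modelled on BasicLSTMCell, not TensorFlow's code: the function lstm_step, its gate order and the shapes are assumptions for illustration. With state_is_tuple=True it unpacks a (c, h) pair directly; with state_is_tuple=False it splits one (batch, 2*n_hidden) array in half. It also assumes the question's istate is a single concatenated tensor of that shape, which the question itself does not show.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, state, W, b, forget_bias=1.0, state_is_tuple=True):
    """One step of a hypothetical NumPy LSTM cell, loosely modelled on BasicLSTMCell."""
    if state_is_tuple:
        # A graph-mode TF Tensor is not iterable, so `c, h = state` fails on it.
        # NumPy arrays *are* iterable, so this explicit check stands in for that error.
        if not isinstance(state, tuple):
            raise TypeError("state_is_tuple=True expects a (c, h) tuple")
        c, h = state
    else:
        # Concatenated layout: one (batch, 2*n_hidden) array, split into c and h.
        c, h = np.split(state, 2, axis=1)
    # Gate pre-activations from [x, h]; W is (n_input + n_hidden, 4*n_hidden).
    z = np.concatenate([x, h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    if state_is_tuple:
        return new_h, (new_c, new_h)
    return new_h, np.concatenate([new_c, new_h], axis=1)


batch, n_in, n_hid = 3, 5, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, n_in))
W = rng.standard_normal((n_in + n_hid, 4 * n_hid))
b = np.zeros(4 * n_hid)

# Hypothetical initial state in the concatenated layout.
flat_state = np.zeros((batch, 2 * n_hid))

out, new_state = lstm_step(x, flat_state, W, b, state_is_tuple=False)
print(out.shape, new_state.shape)  # (3, 4) (3, 8)

try:
    # Concatenated state handed to a cell that expects a tuple.
    lstm_step(x, flat_state, W, b, state_is_tuple=True)
except TypeError as err:
    print("TypeError:", err)
```

The same flat state works with one setting and fails with the other, which mirrors the version-dependent default described in this answer.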

Since TensorFlow r0.11 (and in the master version), state_is_tuple defaults to True, so the cell expects its state as a (c, h) tuple and unpacks it with c, h = state. See the documentation here:

BasicLSTMCell documentation in r0.11

If you have installed r0.11 or the master version of TensorFlow, try changing the BasicLSTMCell initialization line to: lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False). The error you are encountering should then go away. Note, however, that the documentation says the state_is_tuple=False behavior will be deprecated soon.

BasicLSTMCell state_is_tuple argument documentation
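If you would rather keep state_is_tuple=True (the non-deprecated path), another option is to give the cell a (c, h) pair instead of one concatenated tensor. The NumPy sketch below shows the two layouts and how to convert between them; the helper names to_tuple_state and to_flat_state are hypothetical. In the TensorFlow code, the rough equivalent would be splitting _istate in half along axis 1 before passing it as initial_state, or building the initial state with lstm_cell.zero_state(batch_size, tf.float32); treat that mapping as an assumption and check it against the documentation for your version.

```python
import numpy as np


def to_tuple_state(flat_state):
    """Split a (batch, 2*n_hidden) state into a (c, h) pair (hypothetical helper)."""
    c, h = np.split(flat_state, 2, axis=1)
    return (c, h)


def to_flat_state(state):
    """Concatenate a (c, h) pair back into one (batch, 2*n_hidden) array (hypothetical helper)."""
    c, h = state
    return np.concatenate([c, h], axis=1)


batch_size, n_hidden = 2, 3
# Hypothetical concatenated state: first half of each row is c, second half is h.
flat = np.arange(batch_size * 2 * n_hidden, dtype=float).reshape(batch_size, 2 * n_hidden)

c, h = to_tuple_state(flat)
print(c.shape, h.shape)  # (2, 3) (2, 3)
restored = to_flat_state((c, h))  # round trip back to the concatenated layout
```

The split puts c first and h second, matching the order in which the concatenated layout stores them.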

Zhongyu Kuang answered Nov 15 '22 15:11



I happened to run into the same problem at about the same time. I'll describe my situation, which may help you.

My code was:

c1_ex, T1_ex = tf.ones(10, tf.int32)

which raised a TypeError ...

I found that the left side of '=' names two variables in advance, while the right side returns just a single tensor, so Python cannot unpack it into two names.

Your problem actually occurs at line 146, not at line 193.
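The situation this answer describes (two names on the left of '=', one tensor on the right) is ordinary Python tuple unpacking failing on a non-iterable object. The pure-Python sketch below reproduces it without TensorFlow; SingleTensor and ones are hypothetical stand-ins, not TensorFlow APIs.

```python
class SingleTensor:
    """Hypothetical stand-in for a graph-mode Tensor: one object with no __iter__."""

    def __init__(self, shape):
        self.shape = shape


def ones(n):
    # Returns ONE object, the way tf.ones(10, tf.int32) returns a single tensor.
    return SingleTensor((n,))


try:
    # Two names on the left, one non-iterable value on the right.
    c1_ex, T1_ex = ones(10)
except TypeError as err:
    print("TypeError:", err)

# Fixes: bind the single result to one name, or produce two values explicitly.
t = ones(10)
c1_ex, T1_ex = ones(10), ones(10)
print(t.shape, c1_ex.shape, T1_ex.shape)
```

The same mismatch is behind the question's error: the cell's c, h = state line receives one tensor where it expects a pair.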

user8097598 answered Nov 15 '22 14:11
