
AttributeError: 'tensorflow.python.ops.rnn' has no attribute 'rnn'

Tags: tensorflow

I am following this tutorial on Recurrent Neural Networks.

These are the imports:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.ops import rnn
from tensorflow.contrib.rnn import core_rnn_cell

This is the code for input processing:

x = tf.transpose(x, [1,0,2])
x = tf.reshape(x, [-1, chunk_size])
x = tf.split(x, n_chunks, 0)

lstm_cell = core_rnn_cell.BasicLSTMCell(rnn_size)
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
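
The transpose, reshape, and split lines above turn the input into a Python list with one entry per time step, which is the form the RNN call consumes. A minimal pure-Python sketch of the same rearrangement on nested lists follows. It assumes x starts out with shape [batch_size, n_chunks, chunk_size]; the question does not show how x is defined, so that layout and the helper name to_time_major_list are illustrative assumptions.

```python
# Pure-Python stand-in for the transpose / reshape / split steps.
# Assumption: the input has shape [batch_size][n_chunks][chunk_size];
# the question does not show how x is defined.

def to_time_major_list(batch, n_chunks):
    """Return a list with one [batch_size][chunk_size] block per time step."""
    return [[list(example[t]) for example in batch] for t in range(n_chunks)]


# Toy input: 2 examples, 3 chunks (time steps), 4 values per chunk.
# Each value encodes its position as example*100 + step*10 + index.
batch = [
    [[e * 100 + t * 10 + k for k in range(4)] for t in range(3)]
    for e in range(2)
]

steps = to_time_major_list(batch, n_chunks=3)

print(len(steps))        # number of time steps
print(len(steps[0]))     # examples per time step
print(len(steps[0][0]))  # values per chunk
```

Each element of steps plays the role of one tensor in the list returned by tf.split: all examples in the batch, restricted to a single time step.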

I am getting the following error on the outputs, states line:

AttributeError: module 'tensorflow.python.ops.rnn' has no attribute 'rnn'

TensorFlow was updated recently, so what should the new code for the offending line be?

asked Feb 18 '17 at 04:02 by suku


1 Answer

For people using a newer version of TensorFlow, where rnn.rnn no longer exists, use this code:

from tensorflow.contrib import rnn 


lstm_cell = rnn.BasicLSTMCell(rnn_size) 
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

instead of

from tensorflow.python.ops import rnn, rnn_cell 
lstm_cell = rnn_cell.BasicLSTMCell(rnn_size,state_is_tuple=True) 
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
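
If the same script has to run against both old and new installs, one option is to look the function up by name rather than hard-coding it. The sketch below is only an illustration of that idea: pick_rnn_fn is a hypothetical helper, and the two SimpleNamespace objects are stand-ins for the real modules so that the example runs without TensorFlow installed.

```python
from types import SimpleNamespace


def pick_rnn_fn(*modules):
    """Return the first callable named static_rnn or rnn found on the modules."""
    for module in modules:
        for name in ("static_rnn", "rnn"):
            fn = getattr(module, name, None)
            if callable(fn):
                return fn
    raise AttributeError("no static_rnn or rnn function found")


# Hypothetical stand-ins for the old and new modules (not the real TensorFlow API).
old_ops_rnn = SimpleNamespace(
    rnn=lambda cell, inputs, dtype=None: "old rnn.rnn"
)
new_contrib_rnn = SimpleNamespace(
    static_rnn=lambda cell, inputs, dtype=None: "new rnn.static_rnn"
)

# The helper prefers static_rnn and falls back to rnn when that is all there is.
print(pick_rnn_fn(new_contrib_rnn)(None, []))
print(pick_rnn_fn(old_ops_rnn)(None, []))
```

With TensorFlow installed, the imported module (for example the rnn module from tensorflow.contrib) would be passed in place of the stand-ins, and the returned function called as in the snippets above.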

PS: @BrendanA suggested using tf.nn.rnn_cell.LSTMCell instead of rnn_cell.BasicLSTMCell.

answered Jan 03 '23 at 15:01 by suku