
Cannot replace LSTMBlockCell with LSTMBlockFusedCell in Python TensorFlow

Replacing LSTMBlockCell with LSTMBlockFusedCell throws a TypeError in static_rnn. I'm using TensorFlow 1.2.0-rc1 compiled from source.

The full error message:

TypeError                                 Traceback (most recent call last)
<ipython-input-3-2986e054cb6b> in <module>()
     19     enc_cell = tf.contrib.rnn.LSTMBlockFusedCell(rnn_size)
     20     enc_layers = tf.contrib.rnn.MultiRNNCell([enc_cell] * num_layers, state_is_tuple=True)
---> 21     _, enc_state = tf.contrib.rnn.static_rnn(enc_layers, enc_input_unstacked, dtype=dtype)
     23 with tf.variable_scope('decoder'):

~/Virtualenvs/scikit/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)
   1140   if not _like_rnncell(cell):
-> 1141     raise TypeError("cell must be an instance of RNNCell")
   1142   if not nest.is_sequence(inputs):
   1143     raise TypeError("inputs must be a sequence")

TypeError: cell must be an instance of RNNCell

Code to reproduce:

import tensorflow as tf

batch_size = 8
enc_input_length = 1000

dtype = tf.float32
rnn_size = 8
num_layers = 2

enc_input = tf.placeholder(dtype, shape=[batch_size, enc_input_length, 1])
enc_input_unstacked = tf.unstack(enc_input, axis=1)

with tf.variable_scope('encoder'):
    enc_cell = tf.contrib.rnn.LSTMBlockFusedCell(rnn_size)
    enc_layers = tf.contrib.rnn.MultiRNNCell([enc_cell] * num_layers)
    _, enc_state = tf.contrib.rnn.static_rnn(enc_layers, enc_input_unstacked, dtype=dtype)

_like_rnncell (in tensorflow/python/ops/rnn.py, the file from the traceback) looks like this:

def _like_rnncell(cell):
  """Checks that a given object is an RNNCell by using duck typing."""
  conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
                hasattr(cell, "zero_state"), callable(cell)]
  return all(conditions)

It turns out that LSTMBlockFusedCell doesn't have the output_size and state_size properties that LSTMBlockCell implements, so this check fails.
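The failing check can be reproduced in plain Python to see exactly which condition trips. This is only a sketch: like_rnncell mirrors the _like_rnncell code above, while missing_rnncell_attrs and the two classes FakeBlockCell and FakeFusedCell are hypothetical stand-ins, not TensorFlow classes. The fake fused cell is given a zero_state method purely so the report isolates the two properties named above; whether the real class has zero_state is not something this sketch checks.

```python
RNNCELL_ATTRS = ("output_size", "state_size", "zero_state")


def like_rnncell(cell):
    """Same duck-typing rule as the _like_rnncell function shown above."""
    return all(hasattr(cell, name) for name in RNNCELL_ATTRS) and callable(cell)


def missing_rnncell_attrs(cell):
    """Hypothetical helper: list which of the checked attributes a cell lacks."""
    missing = [name for name in RNNCELL_ATTRS if not hasattr(cell, name)]
    if not callable(cell):
        missing.append("__call__")
    return missing


class FakeBlockCell:
    """Hypothetical stand-in for LSTMBlockCell (an RNNCell)."""
    output_size = 8
    state_size = (8, 8)

    def zero_state(self, batch_size, dtype):
        return None

    def __call__(self, inputs, state):
        return inputs, state


class FakeFusedCell:
    """Hypothetical stand-in for LSTMBlockFusedCell: callable on a whole
    sequence, but without output_size / state_size properties."""

    def zero_state(self, batch_size, dtype):
        return None

    def __call__(self, inputs, initial_state=None, dtype=None):
        return inputs, initial_state


print(missing_rnncell_attrs(FakeBlockCell()))  # []
print(missing_rnncell_attrs(FakeFusedCell()))  # ['output_size', 'state_size']
```

Because the check is pure duck typing, any object lacking either property is rejected before static_rnn looks at the inputs.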

Is this a bug, or is there a way to use LSTMBlockFusedCell that I'm missing?

Asked by Tillmann Radmer on Jun 13 '17 at 12:06

1 Answer

LSTMBlockFusedCell inherits from FusedRNNCell rather than RNNCell, so you cannot pass it to the standard tf.nn.static_rnn or tf.nn.dynamic_rnn, which require an RNNCell instance (as your error message shows).

However, as the documentation shows, you can call the cell directly on the whole input sequence to get the complete outputs and the final state:

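import tensorflow as tf  # TensorFlow 1.x; tf.contrib is not available in 2.x

time_len, batch_size, input_size, num_units = 1000, 8, 1, 8  # example sizes

# Fused cells take the whole sequence at once, time-major
inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
fused_rnn_cell = tf.contrib.rnn.LSTMBlockFusedCell(num_units)

outputs, state = fused_rnn_cell(inputs, dtype=tf.float32)

# outputs shape is (time_len, batch_size, num_units)
# state: LSTMStateTuple where c shape is (batch_size, num_units)
#  and h shape is also (batch_size, num_units).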

Internally, LSTMBlockFusedCell calls gen_lstm_ops.block_lstm, a single op that processes the whole sequence and should be equivalent to an ordinary step-by-step LSTM loop.
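To make "an ordinary step-by-step LSTM loop" concrete, here is a NumPy reference loop over a time-major sequence. It produces outputs and a final (c, h) state with the same shapes as the comments in the snippet above. It is only a sketch: reference_lstm is a hypothetical helper, and the packed weight layout, the gate order and the forget_bias handling are assumptions, not a claim about block_lstm's exact kernel.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def reference_lstm(inputs, w, b, forget_bias=1.0):
    """Plain LSTM loop over a time-major sequence (hypothetical helper).

    inputs: (time_len, batch_size, input_size)
    w:      (input_size + num_units, 4 * num_units), packed gate weights (assumed layout)
    b:      (4 * num_units,)
    Returns outputs (time_len, batch_size, num_units) and the final (c, h),
    each of shape (batch_size, num_units).
    """
    time_len, batch_size, _ = inputs.shape
    num_units = b.shape[0] // 4
    c = np.zeros((batch_size, num_units))
    h = np.zeros((batch_size, num_units))
    outputs = []
    for t in range(time_len):
        # One matmul over [x_t, h_{t-1}] gives all four gate pre-activations
        z = np.concatenate([inputs[t], h], axis=1) @ w + b
        i, g, f, o = np.split(z, 4, axis=1)  # gate order is an assumption
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs), (c, h)


time_len, batch_size, input_size, num_units = 5, 8, 1, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((time_len, batch_size, input_size))
w = rng.standard_normal((input_size + num_units, 4 * num_units)) * 0.1
b = np.zeros(4 * num_units)

outputs, (c, h) = reference_lstm(x, w, b)
print(outputs.shape, c.shape, h.shape)  # (5, 8, 8) (8, 8) (8, 8)
```

The fused op does the same work in one kernel call instead of one graph step per time step, which is where its speed advantage over a static_rnn unroll comes from.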

Also note that the input to any FusedRNNCell must be time-major, i.e. shaped [time_len, batch_size, input_size]. If your data is batch-major, as in the question, just transpose the tensor before calling the cell.
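The axis permutation can be checked in NumPy with the question's shapes. In TensorFlow the same permutation would be tf.transpose(enc_input, [1, 0, 2]), and applying it again turns the fused cell's time-major outputs back into batch-major ones if you need them. The array below is dummy data standing in for the placeholder.

```python
import numpy as np

# Shapes from the question: batch-major [batch_size, enc_input_length, 1]
batch_size, enc_input_length = 8, 1000
enc_input = np.arange(batch_size * enc_input_length, dtype=np.float32).reshape(
    batch_size, enc_input_length, 1)

# Swap the first two axes to get time-major [enc_input_length, batch_size, 1]
enc_input_time_major = np.transpose(enc_input, (1, 0, 2))
print(enc_input_time_major.shape)  # (1000, 8, 1)

# Element [t, b] of the time-major array is element [b, t] of the original
assert enc_input_time_major[3, 5, 0] == enc_input[5, 3, 0]
```

The same permutation is its own inverse, so one perm list serves both directions.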

Answered by Mark Dong on Sep 18 '22 at 05:09