
How to set Tensorflow dynamic_rnn, zero_state without a fixed batch_size?

According to TensorFlow's official documentation (https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state), zero_state requires a batch_size. Many examples I found use this code:

    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)

Fixing the batch size is fine for training, but at prediction time the test set may not have the same batch size as the training set. For example, one batch of my training data has shape [100, 255, 128]: a batch size of 100, with 255 steps and 128 inputs. My test set, however, has shape [2000, 255, 128]. I can't run prediction, because initial_state in dynamic_rnn has already fixed batch_size = 100. How do I fix this?

Thanks.

Asked by David on Jul 12 '17

2 Answers

You can specify batch_size as a placeholder rather than a constant. Just make sure to feed the relevant number through feed_dict; it will differ between training and testing.

Importantly, specify [] as the dimensions for the placeholder, because you might get errors if you specify None, as is customary elsewhere. So something like this should work:

    # Scalar placeholder: note the [] shape rather than None.
    batch_size = tf.placeholder(tf.int32, [], name='batch_size')
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
            initial_state=init_state, time_major=False)
    # rest of your code
    # (if X_in is a placeholder, feed it here as well)
    out = sess.run(outputs, feed_dict={batch_size: 100})
    out = sess.run(outputs, feed_dict={batch_size: 10})

Obviously, make sure the batch parameter matches the shape of your inputs, which dynamic_rnn will interpret as [batch_size, seq_len, features], or [seq_len, batch_size, features] if time_major is set to True. A fuller sketch is below.
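
For completeness, here is a minimal end-to-end sketch using the question's shapes (255 steps, 128 inputs) in TF 1.x graph mode; the hidden size n_hidden = 64 and the zero-filled arrays are assumptions for illustration only. Note that X_in itself must leave its batch dimension as None for this to work:

    import numpy as np
    import tensorflow as tf

    n_steps, n_inputs, n_hidden = 255, 128, 64

    # The input placeholder must also leave the batch dimension open (None).
    X_in = tf.placeholder(tf.float32, [None, n_steps, n_inputs], name='X_in')
    # Scalar placeholder for the batch size: note the [] shape.
    batch_size = tf.placeholder(tf.int32, [], name='batch_size')

    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
            initial_state=init_state, time_major=False)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # A training-sized batch of 100 ...
        train_x = np.zeros((100, n_steps, n_inputs), dtype=np.float32)
        sess.run(outputs, feed_dict={X_in: train_x, batch_size: 100})
        # ... and a test-sized batch of 2000 through the same graph.
        test_x = np.zeros((2000, n_steps, n_inputs), dtype=np.float32)
        sess.run(outputs, feed_dict={X_in: test_x, batch_size: 2000})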

Answered by VS_FF on Nov 01 '22


There is a fairly simple fix: just remove the initial_state! The problem is that the initialization process may pre-allocate a state tensor with a fixed batch size.
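
As a sketch under the question's shapes (TF 1.x; n_hidden = 64 is an assumed hyperparameter): when no initial_state is given, dynamic_rnn builds a zero state from the runtime batch size itself, and you must then pass dtype explicitly:

    import tensorflow as tf

    n_steps, n_inputs, n_hidden = 255, 128, 64
    X_in = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)

    # No initial_state: dynamic_rnn creates a zero state matching the
    # batch size of X_in at run time; dtype is then required.
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
            dtype=tf.float32, time_major=False)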

Answered by 陈狗蛋 on Nov 01 '22