TensorFlow: How to use CudnnLSTM with variable input length (like dynamic_rnn)?

I would like to speed up my LSTM network, but as I am using it for OCR (where sequences have variable length), I cannot use the plain LSTM implementation. That is why I use "tf.nn.dynamic_rnn".
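For context, `tf.nn.dynamic_rnn` handles variable lengths by taking a right-padded batch plus a `sequence_length` vector with one true length per example. The stdlib sketch below only shows how such a pair is built. The helper name `pad_batch` and the toy integer data are hypothetical, and each timestep is simplified to a single value instead of a feature vector.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    sequences: Sequence[Sequence[int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad every sequence to the longest one and record its true length."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths, default=0)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


# Hypothetical batch of three sequences with different lengths.
batch, lengths = pad_batch([[5, 3, 9], [7], [2, 2, 4, 1]])
print(batch)    # [[5, 3, 9, 0], [7, 0, 0, 0], [2, 2, 4, 1]]
print(lengths)  # [3, 1, 4]
```

The `lengths` list is what would be passed as `sequence_length`, so the RNN knows where each row's real data ends.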

Based on the RNN benchmark in TensorFlow (https://github.com/tensorflow/tensorflow/blob/754048a0453a04a761e112ae5d99c149eb9910dd/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py#L77), the CUDNN implementation builds the whole model at once (it does not use the "tf.nn.rnn" structure like the others). I assume it may be impossible to use CUDNN with variable lengths, but has anybody succeeded?

The second issue is using "tf.nn.bidirectional_dynamic_rnn", as I would like to use a Bi-LSTM for OCR. But this should be resolved once the first part is solved.

Edit: It looks like "tf.contrib.cudnn_rnn.CudnnLSTM" has a "bidirectional" implementation built in. So the only open question is whether CUDNN can be used with variable-length input sequences.
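Variable lengths matter even more for the bidirectional case: the backward direction has to start at each sequence's last real step, not at its padding. `tf.reverse_sequence` reverses only the first `seq_lengths[i]` steps of each row, and `tf.nn.bidirectional_dynamic_rnn` relies on this kind of length-aware reversal when a sequence length is given. The pure-Python function below is a hypothetical stand-in that mirrors that behaviour on toy data; it is not TensorFlow code.

```python
from typing import List, Sequence


def reverse_sequence(
    batch: Sequence[Sequence[int]], lengths: Sequence[int]
) -> List[List[int]]:
    """Reverse only the first lengths[i] steps of row i; padding stays at the end."""
    out = []
    for row, n in zip(batch, lengths):
        out.append(list(row[:n])[::-1] + list(row[n:]))
    return out


padded = [[5, 3, 9, 0], [7, 0, 0, 0], [2, 2, 4, 1]]
lengths = [3, 1, 4]

print(reverse_sequence(padded, lengths))  # [[9, 3, 5, 0], [7, 0, 0, 0], [1, 4, 2, 2]]
# A naive full reversal would make the backward pass start on padding:
print(padded[0][::-1])                    # [0, 9, 3, 5]
```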

Alternatively, any working example that uses 'CudnnLSTM' would be helpful.

melgor89 asked Oct 27 '16 07:10



2 Answers

Just found this:

tf.contrib.cudnn_rnn.CudnnLSTM currently does not support batches with sequences of different length, thus this is normally not an option to use.

Source: http://returnn.readthedocs.io/en/latest/tf_lstm_benchmark.html
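A kernel without length support steps every row through all padded timesteps, so the final state of a short sequence is also updated by its padding. The toy linear recurrence below stands in for an LSTM cell (an assumption made to keep the example small) and shows the effect: the state read at the end of the padded row differs from the state at the sequence's true end. With a real LSTM, biases and gates would alter the state on zero inputs too.

```python
from typing import Sequence


def final_state(xs: Sequence[float], decay: float = 0.5) -> float:
    """Toy recurrence h_t = decay * h_{t-1} + x_t, used as a stand-in for an RNN cell."""
    h = 0.0
    for x in xs:
        h = decay * h + x
    return h


seq = [1.0, 2.0]
padded = seq + [0.0, 0.0]  # padded to length 4 so it can share a batch

print(final_state(seq))     # 2.5
print(final_state(padded))  # 0.625
```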

Aidas Liaudanskas answered Nov 15 '22 09:11



TensorFlow will soon finally have support for variable sequence lengths: https://github.com/tensorflow/tensorflow/blob/2f672ee9562a452f8dbfa259a8ccec56367e9b17/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py#L389

It looks like it landed too late for 1.13, so it'll probably only be available on TensorFlow 1.14.

You can try it out today by installing the tf-nightly-gpu package and passing sequence_lengths=lengths, where lengths is a tf.int32 Tensor with shape [batch_size] containing the length of each sequence in your batch.
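To make the role of `lengths` concrete, here is a toy, stdlib-only sketch of length-aware recurrence in the spirit of `tf.nn.dynamic_rnn`'s `sequence_length`: once a row passes its length, its output is zeroed and its state is carried through unchanged. The function `run_with_lengths` and the linear "cell" are hypothetical. Whether the CudnnLSTM `sequence_lengths` path produces exactly the same padded outputs is an assumption to check against the docs for your TensorFlow version.

```python
from typing import List, Sequence, Tuple


def run_with_lengths(
    batch: Sequence[Sequence[float]], lengths: Sequence[int], decay: float = 0.5
) -> Tuple[List[List[float]], List[float]]:
    """Toy recurrence honouring per-row lengths: zero outputs and frozen state past the end."""
    outputs, states = [], []
    for row, n in zip(batch, lengths):
        h, outs = 0.0, []
        for t, x in enumerate(row):
            if t < n:
                h = decay * h + x  # real timestep: update the state
                outs.append(h)
            else:
                outs.append(0.0)   # padding: emit zero, keep h unchanged
        outputs.append(outs)
        states.append(h)
    return outputs, states


batch = [[1.0, 2.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
lengths = [2, 4]  # plays the role of the int32 [batch_size] lengths tensor
outs, states = run_with_lengths(batch, lengths)
print(states)   # [2.5, 1.875]
print(outs[0])  # [1.0, 2.5, 0.0, 0.0]
```

Here the first row's final state is 2.5, the same value as running the unpadded sequence, which is what length support buys you.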

Reuben Morais answered Nov 15 '22 10:11
