 

Simple example of CuDnnGRU-based RNN implementation in TensorFlow

Tags:

tensorflow

rnn

I am using the following code for standard GRU implementation:

def BiRNN_deep_dynamic_FAST_FULL_autolength(x, batch_size, dropout, hidden_dim):
    seq_len = length_rnn(x)

    with tf.variable_scope('forward'):
        lstm_cell_fwd = tf.contrib.rnn.GRUCell(
            hidden_dim,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            bias_initializer=tf.contrib.layers.xavier_initializer())
        lstm_cell_fwd = tf.contrib.rnn.DropoutWrapper(lstm_cell_fwd, output_keep_prob=dropout)
    with tf.variable_scope('backward'):
        lstm_cell_back = tf.contrib.rnn.GRUCell(
            hidden_dim,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            bias_initializer=tf.contrib.layers.xavier_initializer())
        lstm_cell_back = tf.contrib.rnn.DropoutWrapper(lstm_cell_back, output_keep_prob=dropout)

    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=lstm_cell_fwd, cell_bw=lstm_cell_back, inputs=x,
        sequence_length=seq_len, dtype=tf.float32, time_major=False)
    outputs_fwd, outputs_bck = outputs

    # fwd_matrix keeps the last valid ([-1]) output vector of every sequence:
    # shape [batch_size, hidden_dim] (99, 64 here).
    fwd_matrix = tf.gather_nd(outputs_fwd,
                              tf.stack([tf.range(batch_size), seq_len - 1], axis=1))

    outputs_fwd = tf.transpose(outputs_fwd, [1, 0, 2])
    outputs_bck = tf.transpose(outputs_bck, [1, 0, 2])

    return outputs_fwd, outputs_bck, fwd_matrix

Can anyone provide a simple example of how to use tf.contrib.cudnn_rnn.CudnnGRU in a similar fashion? Just swapping out the cells doesn't work.

The first issue is that there is no dropout wrapper for the CudnnGRU cell, which is fine. Second, it doesn't seem to work with tf.nn.bidirectional_dynamic_rnn. Any help is appreciated.

asked Mar 08 '18 by Thrabbit


1 Answer

CudnnGRU is not an RNNCell instance. It's more akin to dynamic_rnn.

The tensor manipulations below are equivalent, where input_tensor is a time-major tensor, i.e. of shape [max_sequence_length, batch_size, embedding_size]. CudnnGRU expects the input tensor to be time-major (as opposed to the more common batch-major format, i.e. shape [batch_size, max_sequence_length, embedding_size]). It is good practice to use time-major tensors with RNN ops anyway, since they are somewhat faster.

CudnnGRU:

rnn = tf.contrib.cudnn_rnn.CudnnGRU(
  num_rnn_layers, hidden_size, direction='bidirectional')

rnn_output = rnn(input_tensor)
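
For context, a minimal usage sketch might look like the following; the placeholder shape, layer count and unit count are illustrative assumptions, not values from the question, and the exact unpacking of the returned state is worth checking against the docs for your TF 1.x version:

import tensorflow as tf

# Batch-major input as in the question: [batch_size, max_seq_len, embedding_size].
x = tf.placeholder(tf.float32, [None, None, 128])

# CudnnGRU wants time-major input, so swap the batch and time axes first.
input_tensor = tf.transpose(x, [1, 0, 2])  # [max_seq_len, batch_size, 128]

rnn = tf.contrib.cudnn_rnn.CudnnGRU(
  num_layers=2, num_units=256, direction='bidirectional')

# The layer returns (outputs, (final_hidden,)):
#   outputs:      [max_seq_len, batch_size, 2 * num_units]
#   final_hidden: [num_layers * 2, batch_size, num_units]
outputs, (final_hidden,) = rnn(input_tensor)

# Transpose back if the rest of the graph is batch-major.
outputs = tf.transpose(outputs, [1, 0, 2])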

CudnnCompatibleGRUCell:

rnn_output = input_tensor
# `inputs` is assumed to be the time-major integer token tensor (padding = 0),
# so summing the sign over the time axis gives each sequence's true length.
sequence_length = tf.reduce_sum(
  tf.sign(inputs),
  reduction_indices=0)  # 1 if `input_tensor` is batch-major.

for layer in range(num_rnn_layers):
  with tf.variable_scope('birnn_layer_%d' % layer):  # one scope per stacked layer
    fw_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(hidden_size)
    bw_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(hidden_size)
    (out_fw, out_bw), final_states = tf.nn.bidirectional_dynamic_rnn(
      fw_cell, bw_cell, rnn_output, sequence_length=sequence_length,
      dtype=tf.float32, time_major=True)  # Set `time_major` accordingly
    # Concatenate the forward and backward outputs so the next layer (and the
    # caller) receives a single [max_time, batch, 2 * hidden_size] tensor;
    # `final_states` holds this layer's (state_fw, state_bw) pair.
    rnn_output = tf.concat((out_fw, out_bw), axis=2)

Note the following:

  1. If you were using LSTMs, you would not need CudnnCompatibleLSTMCell; the standard LSTMCell works. But with GRUs, the Cudnn implementation has inherently different math operations and, in particular, more weights (see the documentation).
  2. Unlike dynamic_rnn, CudnnGRU doesn't allow you to specify sequence lengths. It is still over an order of magnitude faster, but you will have to be careful about how you extract your outputs (e.g. if you're interested in the final hidden state of each padded, variable-length sequence, you will need each sequence's length; a sketch of this follows the list).
  3. The return values of these ops are tuples with lots of (distinct) stuff. Refer to the documentation, or just print them out, to inspect which parts of the output you need.
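
To make point 2 concrete, here is a small sketch (names are illustrative) of pulling each sequence's last valid output out of a time-major output tensor, given a vector of true sequence lengths (computed e.g. like sequence_length above or via the question's length_rnn helper):

def last_relevant_output(outputs, seq_len):
  # outputs: time-major tensor, [max_seq_len, batch_size, output_dim]
  # seq_len: int32 vector of true (unpadded) lengths, [batch_size]
  batch_size = tf.shape(outputs)[1]
  # One (time, batch) index per sequence: its last valid step.
  indices = tf.stack([seq_len - 1, tf.range(batch_size)], axis=1)
  return tf.gather_nd(outputs, indices)  # [batch_size, output_dim]

This is essentially the same tf.gather_nd trick as in the question, with the index order flipped because the output here is time-major rather than batch-major.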
answered Sep 26 '22 by Daniel Watson