Is there a way to parallelize stacked RNNs over multiple GPUs in TensorFlow?

Is it possible to take the output of a tf.scan operation and stream it directly to a different GPU, effectively running two stacked RNNs on two GPUs in parallel? Something like this:

cell1 = tf.nn.rnn_cell.MultiRNNCell(...)
cell2 = tf.nn.rnn_cell.MultiRNNCell(...)

with tf.device("/gpu:0"):
  # The accumulator a is an (output, state) pair; a[1] threads the
  # previous state back into the cell at each step.
  ys1 = tf.scan(lambda a, x: cell1(x, a[1]), inputs,
          initializer=(tf.zeros([batch_size, state_size]), init_state))

with tf.device("/gpu:1"):
  # ys1[0] holds the per-step outputs of the first stack, which
  # become the inputs to the second stack.
  ys2 = tf.scan(lambda a, x: cell2(x, a[1]), ys1[0],
          initializer=(tf.zeros([batch_size, state_size]), init_state))

Will TensorFlow automatically take care of that optimization, or will the second scan block until the tensor ys1 is fully computed?

asked Aug 29 '16 by Lenar Hoyt



1 Answer

Unfortunately, tf.scan has a "boundary" at its output: all iterations have to complete before the output tensor can be read by downstream operations. However, you can run the different layers of your LSTM stack on different GPUs and still get frame-level parallelism within a single scan. Write your own version of MultiRNNCell that places each LSTM layer on its own device.
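
For illustration, here is a minimal sketch of that idea against the TF 1.x tf.nn.rnn_cell API. The class name MultiDeviceRNNCell and its constructor arguments are made up for this example, not part of TensorFlow:

import tensorflow as tf

class MultiDeviceRNNCell(tf.nn.rnn_cell.RNNCell):
  """Hypothetical MultiRNNCell variant that pins each layer to a device.

  While layer i computes frame t on devices[i], layer i-1 can already
  work on frame t+1, giving pipeline parallelism across the stack.
  """

  def __init__(self, cells, devices):
    self._cells = cells
    self._devices = devices

  @property
  def state_size(self):
    return tuple(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def __call__(self, inputs, state, scope=None):
    cur_inp = inputs
    new_states = []
    for i, (cell, device) in enumerate(zip(self._cells, self._devices)):
      with tf.device(device), tf.variable_scope("cell_%d" % i):
        cur_inp, new_state = cell(cur_inp, state[i])
        new_states.append(new_state)
    return cur_inp, tuple(new_states)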

Also, you probably want to use tf.nn.dynamic_rnn instead of tf.scan; it threads the state through the time steps for you.
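
A rough usage sketch combining the two suggestions (the layer sizes, input shape, and device strings are placeholders, not from the answer):

cells = [tf.nn.rnn_cell.LSTMCell(256) for _ in range(2)]
stack = MultiDeviceRNNCell(cells, devices=["/gpu:0", "/gpu:1"])

# inputs: [batch_size, max_time, input_depth]
inputs = tf.placeholder(tf.float32, [None, None, 128])
outputs, final_state = tf.nn.dynamic_rnn(stack, inputs, dtype=tf.float32)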

answered Oct 03 '22 by Eugene Brevdo