 

TensorFlow CIFAR synchronization point

Reading https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py, the following comment is provided for the function average_gradients: "Note that this function provides a synchronization point across all towers." Is average_gradients a blocking call, and what is meant by "synchronization point"?

I assume this is a blocking call, since in order to calculate the average of the gradients each gradient must be computed individually? But where is the blocking code that waits for all the individual gradient computations?

asked May 03 '17 by blue-sky

1 Answer

The average_gradients function by itself is not a blocking function. It could have been any other function building TensorFlow operations and it would still be a synchronization point. What makes it blocking is that it consumes the argument tower_grads, which depends on all of the graphs created in the preceding for loop.
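For reference, here is a condensed sketch of what average_gradients does in that tutorial (paraphrased from the linked file, so details may differ slightly): for each variable it stacks the per-tower gradients and averages them, so every entry of tower_grads becomes an input to the averaging ops.

```python
import tensorflow as tf  # TF 1.x, as used in the tutorial

def average_gradients(tower_grads):
    """tower_grads: list (one per GPU) of lists of (gradient, variable) pairs."""
    average_grads = []
    for grad_and_vars in zip(*tower_grads):                  # same variable across all towers
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), 0)   # mean over the towers
        v = grad_and_vars[0][1]                              # variable is shared; take tower 0's
        average_grads.append((grad, v))
    return average_grads
```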

Basically, what happens here is the construction of the training graph. First, in the loop for i in xrange(FLAGS.num_gpus), several graph "threads" (towers) are created. Each looks like this:

calculate loss --> calculate gradients --> append to tower_grads

Each of those graph "threads" is assigned to a different GPU through with tf.device('/gpu:%d' % i), and each can run independently of the others (and will later run in parallel). The next time tower_grads is used, without a device specification, a continuation of the graph is created on the main device, binding all those separate graph "threads" into a single one. TensorFlow will make sure that every graph "thread" contributing to tower_grads has finished before running the graph built inside the average_gradients function. Therefore, when sess.run([train_op, loss]) is later called, this is where the synchronization point of the graph takes effect.
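A minimal sketch of that overall structure (names like tower_loss and num_gpus are placeholders standing in for the tutorial's code, not exact copies):

```python
import tensorflow as tf  # TF 1.x style graph construction

opt = tf.train.GradientDescentOptimizer(0.1)
tower_grads = []

for i in range(num_gpus):                    # num_gpus: number of available GPUs
    with tf.device('/gpu:%d' % i):
        loss = tower_loss(i)                 # hypothetical per-tower loss builder
        tower_grads.append(opt.compute_gradients(loss))

# No device specified here: the averaging ops live on the default device and
# take every tower's gradients as inputs, so they cannot run until all towers
# have finished. This is the synchronization point.
grads = average_gradients(tower_grads)
train_op = opt.apply_gradients(grads)

# sess.run([train_op, loss]) runs the towers in parallel and then waits for
# all of them before averaging and applying the gradients.
```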

answered Oct 02 '22 by BlueSun