
How does torch.distributed.barrier() work

I've read all the documentation I could find about torch.distributed.barrier(), but I'm still having trouble understanding how it is used in this script and would really appreciate some help.

So the official doc of torch.distributed.barrier says it "Synchronizes all processes. This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait()."

It's used in two places in the script:

First place

    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

        ... (preprocesses the data and saves the preprocessed data)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier() 

Second place

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

        ... (loads the model and the vocabulary)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

I'm having trouble relating the comments in the code to the behavior of this function as described in the official doc. How does it make sure that only the first process executes the code between the two calls to torch.distributed.barrier(), and why does it only check whether the local rank is 0 before the second call?

Thanks in advance!

asked Jan 15 '20 22:01 by hlu


1 Answer

First you need to understand ranks. In brief: in a multiprocessing context we typically assume that rank 0 is the first, or base, process. The other processes get different ranks, e.g. 1, 2, 3, for four processes in total.

Some operations do not need to run in parallel, or you need only one process to do some preprocessing or caching so that the other processes can then use the result.

In your example, if the first if statement is entered by the non-base processes (rank 1, 2, 3), they will block (or "wait") because they run into the barrier. They wait there, because barrier() blocks until all processes have reached a barrier, but the base process has not reached a barrier yet.

So at this point the non-base processes (1, 2, 3) are blocked, but the base process (0) continues. The base process will do some operations (preprocess and cache data, in this case) until it reaches the second if-statement. There, the base process will run into a barrier. At this point, all processes have stopped at a barrier, meaning that all current barriers can be lifted and all processes can continue. Because the base process prepared the data, the other processes can now use that data.
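
The flow described above can be imitated without PyTorch. The following is a minimal sketch that uses Python's standard threading.Barrier as a stand-in for torch.distributed.barrier(), with four threads playing ranks 0 to 3. The names worker, log, and events are made up for illustration, and the "preprocess" and "load" steps are reduced to log entries; it only shows the ordering, not real distributed training.

```python
import threading

WORLD_SIZE = 4  # four "ranks", as in the explanation above
barrier = threading.Barrier(WORLD_SIZE)  # stand-in for torch.distributed.barrier()
events = []  # shared log of what happened, in order
events_lock = threading.Lock()


def log(message):
    with events_lock:
        events.append(message)


def worker(local_rank):
    if local_rank != 0:
        barrier.wait()  # ranks 1..3 block here until rank 0 also reaches a barrier

    if local_rank == 0:
        log("rank 0: preprocess and cache")  # runs while the other ranks are still blocked
    else:
        log(f"rank {local_rank}: load from cache")

    if local_rank == 0:
        barrier.wait()  # fourth arrival: all waiting threads are released


threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(events[0])  # rank 0: preprocess and cache
```

The order of the last three entries varies between runs, but the first entry is always rank 0's, because no other thread can pass its barrier before rank 0 arrives at its own.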

Perhaps the most important things to understand are:

  • when a process encounters a barrier it will block
  • the position of the barrier is not important (not all processes have to enter the same if-statement, for instance); what matters is that every process calls barrier() the same number of times
  • a process is blocked by a barrier until all processes have encountered a barrier, upon which those barriers are lifted for all processes
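
The two if-statements from the question can be folded into one reusable context manager. This is only a sketch: rank_zero_first is a hypothetical helper, not part of PyTorch. It assumes, as the `not in [-1, 0]` checks in the question suggest, that a local rank of -1 means non-distributed training. The barrier argument is an addition made here so that the control flow can be checked without an initialized process group; when it is left out, the helper falls back to torch.distributed.barrier.

```python
from contextlib import contextmanager


@contextmanager
def rank_zero_first(local_rank, barrier=None):
    """Let local rank 0 run the body first; the other ranks wait, then run it.

    Hypothetical helper, not a PyTorch API.
    """
    if barrier is None:
        # Assumes torch.distributed.init_process_group() was already called.
        import torch.distributed as dist
        barrier = dist.barrier

    if local_rank not in (-1, 0):
        barrier()  # non-zero ranks wait here until rank 0 has finished the body
    yield
    if local_rank == 0:
        barrier()  # rank 0 is done; this call releases the waiting ranks


def demo(rank):
    """Record the order of barrier and body for one rank, using a fake barrier."""
    trace = []
    with rank_zero_first(rank, barrier=lambda: trace.append("barrier")):
        trace.append("body")
    return trace


print(demo(0))   # ['body', 'barrier']
print(demo(1))   # ['barrier', 'body']
print(demo(-1))  # ['body']
```

Each rank calls the barrier exactly once (or not at all when training is not distributed), which is why the calls pair up even though they sit in different branches. One caveat of this sketch: if the body raises on rank 0, the second barrier is never reached and the other ranks keep waiting.
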
answered Nov 03 '22 17:11 by Bram Vanroy