 

PyTorch multiprocessing: Do I need to use Lock() when accessing a shared model?

I have some questions about using the torch.multiprocessing module. Let’s say I have a torch.nn.Module called model and I call model.share_memory() on it.

What happens if two threads call forward(), i.e. model(input), at the same time? Is it safe, or should I use lock mechanisms to ensure that model is not accessed by multiple threads simultaneously? Similarly, what happens if two or more threads each have an optimizer working on model.parameters() and they call optimizer.step() at the same time?

I ask these questions because I often see optimizer.step() called on shared models without lock mechanisms (e.g. in RL implementations of A3C or ACER), and I wonder whether this is safe to do.
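For context, a minimal sketch of the lock-free (Hogwild-style) pattern those A3C-like implementations use: the parameters live in shared memory, and each worker process applies optimizer.step() without any lock. The layer sizes, learning rate, and data below are illustrative assumptions, not taken from any particular implementation.

```python
import torch
import torch.multiprocessing as mp
import torch.nn as nn

def train_step(model, data, target):
    # Each worker builds its own optimizer over the *shared* parameters
    # and applies an update without any lock (Hogwild-style).
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(data), target)
    loss.backward()
    optimizer.step()

if __name__ == "__main__":
    model = nn.Linear(4, 1)
    model.share_memory()  # move parameters into shared memory
    data = torch.randn(8, 4)
    target = torch.randn(8, 1)
    workers = [mp.Process(target=train_step, args=(model, data, target))
               for _ in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```

Because the parameter storages are shared, every in-place update a worker makes is immediately visible to the others; the question is exactly whether those unsynchronized updates are safe.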

asked Dec 03 '20 by Federico Taschin

People also ask

How does torch multiprocessing work?

torch.multiprocessing is a drop-in replacement for Python's multiprocessing module. It supports the exact same operations, but extends it so that all tensors sent through a multiprocessing.Queue will have their data moved into shared memory and will only send a handle to another process.
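A small sketch of that handle-passing behavior, under the assumption of a Unix-like default start method; the tensor values and process structure here are illustrative:

```python
import torch
import torch.multiprocessing as mp

def producer(queue):
    t = torch.ones(3)
    # The tensor's storage is moved into shared memory;
    # only a handle travels through the queue.
    queue.put(t)

if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=producer, args=(q,))
    p.start()
    received = q.get()  # get before join to avoid blocking on the feeder thread
    p.join()
    print(received)  # tensor([1., 1., 1.])
```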

Does multiprocessing work in Python?

The Python language supports multiprocessing, the act of running many processes simultaneously. With it, you can write a program that assigns many tasks to be completed at the same time, saving time and energy.

What is Python multiprocessing?

multiprocessing is a package that supports spawning processes using an API similar to the threading module. The multiprocessing package offers both local and remote concurrency, effectively side-stepping the Global Interpreter Lock by using subprocesses instead of threads.
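A minimal illustration of that API, using a worker pool from the standard library (the function and inputs are made up for the example):

```python
from multiprocessing import Pool

def square(x):
    # Runs in a separate worker process, outside the GIL of the parent.
    return x * x

if __name__ == "__main__":
    with Pool(processes=4) as pool:
        results = pool.map(square, range(5))
    print(results)  # [0, 1, 4, 9, 16]
```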


1 Answer

It doesn't have to be safe, since they run asynchronously, not in parallel. Quoting from the docs:

Using torch.multiprocessing, it is possible to train a model asynchronously, with parameters either shared all the time, or being periodically synchronized. In the first case, we recommend sending over the whole model object, while in the latter, we advise to only send the state_dict().
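The second option from that quote, periodic synchronization via state_dict(), can be sketched as follows; the module shapes are assumptions for illustration:

```python
import torch
import torch.nn as nn

# Periodic-synchronization sketch: a worker keeps a local copy of the
# model and occasionally pulls the current shared weights.
shared = nn.Linear(4, 2)   # stands in for the shared/global model
local = nn.Linear(4, 2)    # the worker's private copy

# Synchronize: local <- shared (copies tensors, no storage is shared)
local.load_state_dict(shared.state_dict())
```

In the first case (sharing the whole model object, as in the Hogwild example above the quote), no such copying step is needed because the workers mutate the same shared parameter storages directly.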

answered Oct 19 '22 by ndrwnaguib