I've encountered a mysterious bug while trying to implement Hogwild with torch.multiprocessing. In particular, one version of the code runs fine, but when I add in a seemingly unrelated bit of code before the multiprocessing step, this somehow causes an error during the multiprocessing step: RuntimeError: Unable to handle autograd's threading in combination with fork-based multiprocessing. See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork
I reproduced the error in a minimal code sample, pasted below. If I comment out the two lines m0 = Model() and train(m0), which carry out a non-parallel training run on a separate model instance, then everything runs fine. I can't figure out how these lines could be causing a problem.
I'm running PyTorch 1.5.1 and Python 3.7.6 on a Linux machine, training on CPU only.
import torch
import torch.multiprocessing as mp
from torch import nn

def train(model):
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    for _ in range(10000):
        opt.zero_grad()
        # We train the model to output the value 4 (arbitrarily)
        loss = (model(0) - 4)**2
        loss.backward()
        opt.step()

# Toy model with one parameter tensor of size 3.
# Output is always the sum of the elements in the tensor,
# independent of the input
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return torch.sum(self.x)

############################################
# Create a separate Model instance and run
# a non-parallel training run.
# For some reason, this code causes the
# subsequent parallel run to fail.
m0 = Model()
train(m0)
print('Done with preliminary run')
############################################

num_processes = 2
model = Model()
model.share_memory()
processes = []
for rank in range(num_processes):
    p = mp.Process(target=train, args=(model,))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
print(model.x)
If you modify your code to create new processes like this:
processes = []
ctx = mp.get_context('spawn')
for rank in range(num_processes):
    p = ctx.Process(target=train, args=(model,))
it seems to run fine (the rest of the code is the same as yours; tested on PyTorch 1.5.0 / Python 3.6 / NVIDIA T4 GPU).
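For reference, here is roughly what the whole script could look like with that change applied. This is a sketch, not a verified fix for every setup. Two things in it are my own additions rather than part of the question's code: the num_steps argument (so the run length can be shortened) and the main() function with its `if __name__ == "__main__":` guard. The guard matters because the spawn start method re-imports the main module in each child, so unguarded top-level code would run again in every worker.

```python
import torch
import torch.multiprocessing as mp
from torch import nn


class Model(nn.Module):
    """Toy model from the question: output is the sum of one 3-element parameter."""

    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return torch.sum(self.x)


def train(model, num_steps=1000):
    # num_steps is a hypothetical knob added for this sketch;
    # the question hard-codes 10000 iterations.
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    for _ in range(num_steps):
        opt.zero_grad()
        loss = (model(0) - 4) ** 2
        loss.backward()
        opt.step()


def main(num_processes=2):
    # Preliminary non-parallel run: this calls backward() in the parent.
    m0 = Model()
    train(m0)
    print("Done with preliminary run")

    model = Model()
    model.share_memory()

    # Spawned children start from a fresh interpreter, so they do not
    # inherit whatever autograd state the parent built up.
    ctx = mp.get_context("spawn")
    processes = []
    for rank in range(num_processes):
        p = ctx.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print(model.x)


if __name__ == "__main__":
    main()
```

Run it as a script file rather than pasting it into an interactive session, since spawned children need to be able to import train and Model from the main module.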
I'm not completely sure what is carried over from the non-parallel run to the parallel run; I tried creating a completely new model for the two runs (with its own class), and/or deleting anything from the original, and/or making sure to delete any tensors and free up memory, and none of that made any difference.
What did make a difference was making sure that .backward() never got called outside of mp.Process() before it was called by a function within mp.Process(). I think what may be carried over is an autograd thread: if the thread exists before multiprocessing with the default fork method, it fails; if the thread is created after the fork, it seems to work okay; and if using spawn, it also works okay.
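If you want to keep the default fork method, one way to act on that observation is to run the preliminary training in its own child process too, so the parent itself never calls .backward(). The sketch below assumes my guess about the autograd thread is right, and it assumes a platform where fork is available (Linux in your case). The run_in_child helper and the num_steps argument are hypothetical names I made up for illustration. I also call share_memory() on m0 so the parent can see the weights the child trained.

```python
import torch
import torch.multiprocessing as mp
from torch import nn


class Model(nn.Module):
    """Toy model from the question: output is the sum of one 3-element parameter."""

    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return torch.sum(self.x)


def train(model, num_steps=1000):
    # num_steps is a hypothetical knob; the question hard-codes 10000.
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    for _ in range(num_steps):
        opt.zero_grad()
        loss = (model(0) - 4) ** 2
        loss.backward()
        opt.step()


def run_in_child(ctx, target, args):
    """Hypothetical helper: run target(*args) in one child and return its exit code."""
    p = ctx.Process(target=target, args=args)
    p.start()
    p.join()
    return p.exitcode


def main(num_processes=2):
    ctx = mp.get_context("fork")

    # Preliminary run, but inside a child: the parent never calls backward().
    m0 = Model()
    m0.share_memory()  # lets the parent see the child's in-place updates
    run_in_child(ctx, train, (m0,))
    print("Done with preliminary run")

    # Hogwild-style parallel run on a second, shared model.
    model = Model()
    model.share_memory()
    processes = [ctx.Process(target=train, args=(model,))
                 for _ in range(num_processes)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    print(model.x)


if __name__ == "__main__":
    main()
```

The trade-off compared with spawn is that this only works as long as nothing in the parent ever triggers a backward pass before the workers are forked, which is easy to break by accident in a larger program.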
Btw: That's a really interesting question - thank you especially for digesting it to a minimal example!