 

How to use PyTorch multiprocessing?

I'm trying to use Python's multiprocessing Pool in PyTorch to process an image. Here's the code:

from multiprocessing import Process, Pool
import torch
from torch.autograd import Variable
import numpy as np
from scipy.ndimage import zoom

def get_pred(args):
    img = args[0]
    scale = args[1]
    scales = args[2]
    img_scale = zoom(img.numpy(),
                     (1., 1., scale, scale),
                     order=1,
                     prefilter=False,
                     mode='nearest')

    # feed input data
    input_img = Variable(torch.from_numpy(img_scale),
                         volatile=True).cuda()
    return input_img

# img: the input image tensor (4-D, since zoom rescales only the last two axes)
scales = [1, 2, 3, 4, 5]
scale_list = []
for scale in scales:
    scale_list.append([img, scale, scales])
multi_pool = Pool(processes=5)
predictions = multi_pool.map(get_pred, scale_list)
multi_pool.close()
multi_pool.join()

I'm getting this error:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

on this line:

predictions = multi_pool.map(get_pred,scale_list)

Can anyone tell me what I'm doing wrong?

asked Feb 16 '18 08:02 by Rahul


2 Answers

As stated in the PyTorch documentation, the best practice for multiprocessing is to use torch.multiprocessing instead of multiprocessing.

Be aware that sharing CUDA tensors between processes is supported only in Python 3, and only with spawn or forkserver as the start method.

Without otherwise changing your code, you can work around the error by replacing

from multiprocessing import Process, Pool

with:

from torch.multiprocessing import Pool, Process, set_start_method
try:
    set_start_method('spawn')
except RuntimeError:
    # the start method has already been set (e.g. by an earlier call)
    pass
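Applied to the question, a minimal sketch might look like the following. It assumes `img` is a 4-D CPU tensor of shape (N, C, H, W) (a random stand-in is used here), and `get_scaled` and `scale_all` are hypothetical names. Two changes go beyond the import swap. First, the pool is created under `if __name__ == '__main__':`, which the spawn start method needs because every worker re-imports the main module. Second, the workers only do the CPU-side `zoom` and the move to the GPU happens in the parent; this is a design choice, not part of the answer above, motivated by the PyTorch docs' note that a process sending a CUDA tensor must keep it alive while the receiver uses it, which fits poorly with pool workers that exit. (`Variable(..., volatile=True)` is also obsolete in current PyTorch; `torch.no_grad()` replaces it.)

```python
import torch
import torch.multiprocessing as mp
from scipy.ndimage import zoom


def get_scaled(args):
    # Runs in a worker process. It only does CPU work (numpy/scipy),
    # so CUDA is never initialized in the child.
    img_np, scale = args
    return zoom(img_np,
                (1., 1., scale, scale),
                order=1,
                prefilter=False,
                mode='nearest')


def scale_all(img, scales, processes=2):
    # Hypothetical helper: one task per scale, each carrying a plain
    # numpy copy of the image so nothing CUDA-related is pickled.
    tasks = [(img.numpy(), scale) for scale in scales]
    with mp.Pool(processes=processes) as pool:
        scaled = pool.map(get_scaled, tasks)
    # Move the results to the GPU (if one exists) in the parent process only.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return [torch.from_numpy(arr).to(device) for arr in scaled]


if __name__ == '__main__':
    # 'spawn' must be chosen before the first pool is created.
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        pass  # start method was already set
    img = torch.rand(1, 3, 8, 8)  # stand-in for the real (N, C, H, W) image
    predictions = scale_all(img, [1, 2, 3])
    print([tuple(p.shape) for p in predictions])
```

Because the workers never touch CUDA, this version would also avoid the original error under the default fork start method on Linux; spawn is kept here to follow the answer.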
answered Nov 14 '22 15:11 by nicobonne


I suggest you read the docs for the multiprocessing module, especially the section on contexts and start methods. You will have to change the way subprocesses are created by calling set_start_method. This example is taken from those docs:

import multiprocessing as mp

def foo(q):
    q.put('hello')

if __name__ == '__main__':
    mp.set_start_method('spawn')
    q = mp.Queue()
    p = mp.Process(target=foo, args=(q,))
    p.start()
    print(q.get())
    p.join()
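The same docs section also describes `get_context()`, which returns a context object with the same API as the multiprocessing module but bound to one start method. A sketch using it instead of the global `set_start_method` call, assuming a toy task `square` and a helper `run_spawned` (both hypothetical): since no global state changes, there is no need to catch the `RuntimeError` raised when the start method is set twice, and library code can pick a start method without affecting its caller. `torch.multiprocessing` re-exports this API, so `torch.multiprocessing.get_context('spawn')` should work the same way.

```python
import multiprocessing as mp


def square(x):
    # Top-level function so spawned workers can import it by name.
    return x * x


def run_spawned(values, processes=2):
    # get_context returns a context with the multiprocessing API, tied to
    # the 'spawn' start method, without changing the global default.
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=processes) as pool:
        return pool.map(square, values)


if __name__ == '__main__':
    results = run_spawned([1, 2, 3, 4])
    print(results)  # [1, 4, 9, 16]
```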
answered Nov 14 '22 16:11 by Oliver Baumann