I am trying to write my data (from a single file in HDF5 format) to multiple files, and it works fine when the task is executed in serial. Now I want to improve the efficiency by using the multiprocessing module, but the output sometimes goes wrong. Here's a simplified version of my code:
import multiprocessing as mp
import numpy as np
import math, h5py, time
N = 4 # number of processes to use
block_size = 300
data_sz = 678
dataFile = 'mydata.h5'
# fake some data
mydata = np.zeros((data_sz, 1))
for i in range(data_sz):
    mydata[i, 0] = i + 1
h5file = h5py.File(dataFile, 'w')
h5file.create_dataset('train', data=mydata)
# fire multiple workers
pool = mp.Pool(processes=N)
total_part = int(math.ceil(1. * data_sz / block_size))
for i in range(total_part):
    pool.apply_async(data_write_func, args=(dataFile, i))
pool.close()
pool.join()
and the structure of data_write_func() is:
def data_write_func(h5file_dir, i, block_size=block_size):
    hf = h5py.File(h5file_dir)
    fout = open('data_part_' + str(i), 'w')
    data_part = hf['train'][block_size*i : min(block_size*(i+1), data_sz)]  # np.ndarray
    for line in data_part:
        # do some processing, that takes a while...
        time.sleep(0.01)
        # then write out..
        fout.write(str(line[0]) + '\n')
    fout.close()
When I set N=1 it works well, but when I set N=2 or N=4 the result sometimes gets messed up (not every time!). E.g. in data_part_1 I expect the output to be:
301,
302,
303,
...
But sometimes what I get is:
0,
0,
0,
...
and sometimes I get:
379,
380,
381,
...
I'm new to the multiprocessing module and find it tricky. I'd appreciate any suggestions!
After fixing the fout.write and mydata=... lines as Andriy suggested, your program works as intended, because every process writes to its own file. There's no way the processes can intermingle with each other.
What you probably wanted to do is use multiprocessing.Pool.map(), which cuts your iterable up for you (so you don't need to do the block_size thingies), and which guarantees that the results come back in order. I've reworked your code to use multiprocessing map:
import multiprocessing

def data_write_func(line):
    i = multiprocessing.current_process()._identity[0]
    line = [x * 2 for x in line]  # don't reuse `i` as the loop variable here
    files[i - 1].write(",".join(str(s) for s in line) + "\n")
N = 4
mydata=[[x+1,x+2,x+3,x+4] for x in range(0,4000*N,4)] # fake some data
files = [open('data_part_'+str(i), 'w') for i in range(N)]
pool = multiprocessing.Pool(processes=N)
pool.map(data_write_func, mydata)
pool.close()
pool.join()
Please note: data_write_func is called for every row, so the file opening needs to be done in the parent process. Also, you don't need to close() the files manually; the OS will do that for you when your Python program exits.

Now, I guess in the end you'd want to have all the output in one file, not in separate files. If your output line is below 4096 bytes on Linux (or below 512 bytes on OS X; for other OSes see here), you're actually safe to just open one file (in append mode) and let every process write into that one file, as writes below these sizes are guaranteed by Unix to be atomic.
Update:
"What if the data is stored in hdf5 file as dataset?"
According to the HDF5 docs, this works out of the box since version 2.2.0:
Parallel HDF5 is a configuration of the HDF5 library which lets you share open files across multiple parallel processes. It uses the MPI (Message Passing Interface) standard for interprocess communication
So if you do this in your code:
h5file = h5py.File(dataFile, 'w')
dset = h5file.create_dataset('train', data=mydata)
Then you can just access dset from within your processes and read/write to it without taking any extra measures. See also this example from h5py on using multiprocessing.
The issue could not be replicated. Here is my full code:
#!/usr/bin/env python
import multiprocessing
N = 4
mydata=[[x+1,x+2,x+3,x+4] for x in range(0,4000*N,4)] # fake some data
def data_write_func(mydata, i, block_size=1000):
    fout = open('data_part_' + str(i), 'w')
    data_part = mydata[block_size*i : block_size*i + block_size]
    for line in data_part:
        # do some processing, say *2 for each element...
        line = [x*2 for x in line]
        # then write out..
        fout.write(','.join(map(str, line)) + '\n')
    fout.close()
pool = multiprocessing.Pool(processes=N)
for i in range(2):
    pool.apply_async(data_write_func, (mydata, i))
pool.close()
pool.join()
Sample output from data_part_0:
2,4,6,8
10,12,14,16
18,20,22,24
26,28,30,32
34,36,38,40
42,44,46,48
50,52,54,56
58,60,62,64