 

How to parallelise a Python input pipeline in distributed TensorFlow

I have a non-trivial input pipeline, written in Python, that reads ground truth and raw data and preprocesses them. Running the input pipeline for a single sample takes a long time, so I have multiple processes (from the Python multiprocessing package) running in parallel, with queues, to perform the work quickly and prefetch data. The output is then fed to my network using feed_dict. With this setup, the input overhead in my training loop is two orders of magnitude smaller than the actual tf.Session.run() time. I'm trying to move to the tf.data API by wrapping my read+preprocess functions with tf.py_func, but it runs slowly, probably due to the GIL, even when I increase the number of parallel calls. I want to scale my training up to multiple machines and am not sure how data fetching behaves in that case, and there is also the performance issue on a single machine :)

So, basically my question is: how do I run Python functions in a tf.data API input pipeline in parallel on multiple CPU cores?

Guy Zohar asked Nov 08 '22


1 Answer

A couple of clarifications: tf.py_func can run in parallel with your sess.run() calls (because sess.run() releases the GIL), but you cannot run multiple tf.py_func calls in parallel within the same Python process.
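For illustration, here is a minimal sketch of the setup the question describes; load_and_preprocess is a hypothetical stand-in for the real read+preprocess logic. Even with num_parallel_calls, its CPU-bound Python body is serialized by the GIL:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the asker's read+preprocess logic: any
# CPU-bound Python work here is still serialized by the GIL.
def load_and_preprocess(path):
    return np.random.rand(4).astype(np.float32)  # pretend we read `path`

paths = tf.data.Dataset.from_tensor_slices(["a.bin", "b.bin", "c.bin"])

# num_parallel_calls launches several tf.py_func calls concurrently, but
# they all share one Python interpreter, so this buys little for
# CPU-bound preprocessing.
dataset = paths.map(
    lambda p: tf.py_func(load_and_preprocess, [p], tf.float32),
    num_parallel_calls=4)
```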

The usual answer in such cases is to do the preprocessing once offline, save the results to disk (e.g. in TFRecord format), and then read the ready-made data from files during training. You can probably parallelize the offline preprocessing using something like multiprocessing.
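A minimal sketch of that offline pass, assuming hypothetical input paths and a stand-in preprocess(); each worker process writes one TFRecord shard, so the GIL is not a bottleneck:

```python
import multiprocessing as mp
import numpy as np
import tensorflow as tf

# Hypothetical preprocessing; replace with the real read+preprocess work.
def preprocess(path):
    return np.random.rand(4).astype(np.float32)

# Each worker process writes its shard of examples to one TFRecord file.
def write_shard(args):
    shard_id, paths = args
    with tf.python_io.TFRecordWriter("train-%03d.tfrecord" % shard_id) as w:
        for path in paths:
            features = preprocess(path)
            example = tf.train.Example(features=tf.train.Features(feature={
                "features": tf.train.Feature(
                    float_list=tf.train.FloatList(value=features))}))
            w.write(example.SerializeToString())

if __name__ == "__main__":
    all_paths = ["a.bin", "b.bin", "c.bin", "d.bin"]
    shards = list(enumerate([all_paths[0::2], all_paths[1::2]]))
    with mp.Pool(2) as pool:
        pool.map(write_shard, shards)
```

At training time the resulting shards can then be read with tf.data.TFRecordDataset, which supports parallel, GIL-free reads via its num_parallel_reads argument.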

If you can express your preprocessing using TF operations, you can run it in parallel using Dataset.map, but there is no built-in support for Python multiprocessing in tf.data. If the above does not work for some reason, you will probably have to hook up multiprocessing yourself.
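For example, a pipeline expressed purely in TF ops (the file names and image transforms below are hypothetical) parallelizes cleanly, because map() runs the ops on C++ threads outside the GIL:

```python
import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(["a.png", "b.png"])

# Decode + augment expressed entirely in TF ops; map() executes these on
# parallel C++ threads, with no Python interpreter involved.
def tf_preprocess(path):
    image = tf.image.decode_png(tf.read_file(path), channels=3)
    image = tf.image.resize_images(image, [224, 224])
    return tf.image.random_flip_left_right(image)

dataset = filenames.map(tf_preprocess, num_parallel_calls=8).prefetch(2)
```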

One way to approach this is the following: have multiple processes produce your inputs and put them into a multiprocessing.Queue (or into shared memory with some locking around it), then implement the receiving side as a generator function and create a dataset using from_generator. A sketch follows.
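Here is a minimal sketch of that approach under assumed data shapes; worker, make_dataset, and the (4,) output shape are all illustrative placeholders, not a fixed API:

```python
import multiprocessing as mp
import numpy as np
import tensorflow as tf

# Hypothetical worker: does the heavy Python preprocessing in its own
# process, so the training process's GIL is not involved.
def worker(paths, queue):
    for path in paths:
        queue.put(np.random.rand(4).astype(np.float32))  # stand-in work
    queue.put(None)  # sentinel: this worker is done

def make_dataset(all_paths, num_workers=4):
    queue = mp.Queue(maxsize=32)
    shards = [all_paths[i::num_workers] for i in range(num_workers)]
    for shard in shards:
        mp.Process(target=worker, args=(shard, queue), daemon=True).start()

    # Receiving side: a plain generator that drains the queue until every
    # worker has sent its sentinel (one pass over the data).
    def gen():
        done = 0
        while done < num_workers:
            item = queue.get()
            if item is None:
                done += 1
            else:
                yield item

    return tf.data.Dataset.from_generator(
        gen, output_types=tf.float32, output_shapes=(4,)).prefetch(8)
```

The maxsize on the queue provides backpressure so the workers prefetch ahead of training without growing memory unboundedly.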

iga answered Nov 14 '22