 

Starmap combined with tqdm?

I am doing some parallel processing, as follows:

import multiprocessing as mp

with mp.Pool(8) as tmpPool:
    results = tmpPool.starmap(my_function, inputs)

where inputs looks like [(1, 0.2312), (5, 0.52), ...], i.e., a list of tuples, each holding an int and a float.
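For reference, starmap unpacks each tuple into positional arguments, so the call above computes [my_function(1, 0.2312), my_function(5, 0.52), ...] in order. The sketch below checks that equivalence; my_function is a hypothetical stand-in, and it uses multiprocessing.pool.ThreadPool (which has the same interface as mp.Pool) only so the example runs without pickling concerns.

```python
from itertools import starmap
from multiprocessing.pool import ThreadPool


def my_function(n, x):
    """Hypothetical stand-in for the real worker: takes an int and a float."""
    return n * x


inputs = [(1, 0.2312), (5, 0.52), (3, 0.1)]

# starmap unpacks each tuple: my_function(1, 0.2312), my_function(5, 0.52), ...
expected = [my_function(*args) for args in inputs]

# ThreadPool is interface-compatible with mp.Pool, so this mirrors the question's call
with ThreadPool(2) as pool:
    parallel = pool.starmap(my_function, inputs)

# Same values, same order as the serial itertools.starmap
assert parallel == expected == list(starmap(my_function, inputs))
```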

The code runs fine, but I can't figure out how to wrap it in a progress bar (tqdm), as can be done with, e.g., the imap method:

tqdm.tqdm(tmpPool.imap(some_function, some_inputs))
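The difference that matters here: imap returns a lazy iterator that hands back each result as it becomes available, which is what tqdm needs to advance, while starmap blocks until every task is done and then returns a finished list. A small sketch showing the two return types; square is a hypothetical worker, and ThreadPool is used as an interface-compatible stand-in for mp.Pool.

```python
from multiprocessing.pool import ThreadPool


def square(x):
    """Hypothetical single-argument worker."""
    return x * x


with ThreadPool(2) as pool:
    # imap returns an iterator; each next() waits only for the next result,
    # so tqdm can advance one tick per finished item.
    lazy = pool.imap(square, range(5))
    lazy_results = list(lazy)
    # starmap blocks until all tasks finish and returns a plain list,
    # so there is nothing incremental for tqdm to iterate over.
    eager = pool.starmap(pow, [(2, 3), (3, 2)])

print(type(lazy).__name__, lazy_results)  # IMapIterator [0, 1, 4, 9, 16]
print(type(eager).__name__, eager)        # list [8, 9]
```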

Can this be done for starmap also?

Thanks!

asked Aug 05 '19 by sdgaw erzswer



3 Answers

It's not possible with starmap(), but it is possible with a patch that adds Pool.istarmap(), based on the code for imap(). All you have to do is create the istarmap.py file and import it to apply the patch before you make your regular multiprocessing imports.

Python <3.8

# istarmap.py for Python <3.8
import multiprocessing.pool as mpp


def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    if self._state != mpp.RUN:
        raise ValueError("Pool not running")

    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))

    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self._cache)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)


mpp.Pool.istarmap = istarmap

Python 3.8+

# istarmap.py for Python 3.8+
import multiprocessing.pool as mpp


def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    self._check_running()
    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))

    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)


mpp.Pool.istarmap = istarmap
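Both versions of the patch rely on private, undocumented parts of multiprocessing.pool, so a future Python release could break them. A possible guard (not part of the original patch) that fails at import time rather than mid-run; the attribute list is taken from the 3.8+ code above.

```python
import multiprocessing.pool as mpp

# Private names used by the Python 3.8+ istarmap patch. None of them are
# documented API. (self._taskqueue is also used, but it is set per
# instance in Pool.__init__, so it cannot be checked on the class.)
PRIVATE_POOL_ATTRS = ("_get_tasks", "_guarded_task_generation", "_check_running")
PRIVATE_MODULE_ATTRS = ("IMapIterator", "starmapstar")


def missing_internals():
    """Return the private names the patch needs that this interpreter lacks."""
    missing = [name for name in PRIVATE_POOL_ATTRS if not hasattr(mpp.Pool, name)]
    missing += [name for name in PRIVATE_MODULE_ATTRS if not hasattr(mpp, name)]
    return missing


# Fail loudly when istarmap.py is imported instead of with an AttributeError later.
_missing = missing_internals()
if _missing:
    raise ImportError(f"istarmap patch needs missing internals: {_missing}")
```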

Then in your script:

import istarmap  # import to apply patch
from multiprocessing import Pool
import tqdm    


def foo(a, b):
    for _ in range(int(50e6)):
        pass
    return a, b    


if __name__ == '__main__':

    with Pool(4) as pool:
        iterable = [(i, 'x') for i in range(10)]
        for _ in tqdm.tqdm(pool.istarmap(foo, iterable),
                           total=len(iterable)):
            pass
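To see why the patch ends with (item for chunk in result for item in chunk): the iterable is split into chunks, each worker runs the chunk through itertools.starmap (that is what mpp.starmapstar does) and returns a list, and the per-chunk lists are flattened back into one stream. A single-process model of that flow; the helper names are hypothetical, and no pool is involved.

```python
from itertools import islice, starmap


def chunks(iterable, size):
    """Yield tuples of up to `size` items, the way the pool batches tasks."""
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, size))
        if not batch:
            return
        yield batch


def istarmap_serial(func, iterable, chunksize=1):
    """Serial model of istarmap: starmap each chunk, then flatten the results."""
    per_chunk_results = (list(starmap(func, batch)) for batch in chunks(iterable, chunksize))
    return (item for chunk in per_chunk_results for item in chunk)


def foo(a, b):
    return a, b


out = list(istarmap_serial(foo, [(i, 'x') for i in range(10)], chunksize=3))
```

Because the flattening generator yields one item at a time, tqdm ticks per input even when chunksize > 1, although items arrive in bursts of one chunk.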

answered Oct 20 '22 by Darkonaut


The simplest way is probably to wrap tqdm() around the inputs rather than around the mapping function. For example:

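import multiprocessing as mp
import tqdm
# param1, param2, param3: equal-length sequences; zip() has no len(), so pass total
inputs = zip(param1, param2, param3)
with mp.Pool(8) as pool:
    results = pool.starmap(my_function, tqdm.tqdm(inputs, total=len(param1)))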
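One caveat, as far as I can tell from CPython's multiprocessing.pool source: with this approach the bar counts inputs as the pool's background task-handler thread pulls them to queue for the workers, not results as they finish, so it can reach 100% well before the work is done (with a process pool the queue is a pipe, so its buffer may throttle this somewhat). The sketch below shows the effect deterministically with workers that are blocked; CountingIterable is a hypothetical stand-in for tqdm(inputs, total=n) (it defines __len__, as tqdm does), and ThreadPool stands in for mp.Pool.

```python
import threading
import time
from multiprocessing.pool import ThreadPool

release = threading.Event()


def slow_add(a, b):
    """Hypothetical worker that blocks until the main thread sets `release`."""
    release.wait()
    return a + b


class CountingIterable:
    """Stand-in for tqdm(inputs, total=n): counts items as the pool pulls them."""

    def __init__(self, items):
        self.items = list(items)
        self.pulled = 0

    def __len__(self):  # tqdm objects define __len__ too
        return len(self.items)

    def __iter__(self):
        for item in self.items:
            self.pulled += 1
            yield item


inputs = CountingIterable((i, i) for i in range(100))

with ThreadPool(4) as pool:
    pending = pool.starmap_async(slow_add, inputs)
    # Give the task-handler thread up to 5 s to pull the inputs.
    deadline = time.monotonic() + 5
    while inputs.pulled < len(inputs) and time.monotonic() < deadline:
        time.sleep(0.01)
    pulled_while_blocked = inputs.pulled
    done_while_blocked = pending.ready()
    release.set()  # now let the workers finish
    results = pending.get(timeout=10)

print(pulled_while_blocked, done_while_blocked)  # 100 False
```

All 100 inputs are consumed while no result has finished, which is why a bar wrapped around the inputs can run ahead of the real progress.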

answered Oct 20 '22 by corey


As Darkonaut mentioned, at the time of writing there's no built-in istarmap. If you want to avoid patching, you can add a simple *_star wrapper function as a workaround. (This solution was inspired by this tutorial.)

import multiprocessing

import tqdm


def my_function(arg1, arg2, arg3):
    return arg1 + arg2 + arg3


def my_function_star(args):
    # Unpack one argument tuple so single-argument imap can call my_function
    return my_function(*args)


if __name__ == '__main__':
    jobs = 4
    with multiprocessing.Pool(jobs) as pool:
        args = [(i, i, i) for i in range(10000)]
        results = list(tqdm.tqdm(pool.imap(my_function_star, args), total=len(args)))

Some notes:

I also really like corey's answer. It's cleaner, though its progress bar does not seem to update as smoothly as mine. Note that, with the code I posted above and chunksize=1 (imap's default), corey's answer runs several orders of magnitude faster. I'm guessing this is due to multiprocessing serialization overhead, because increasing chunksize (or making my_function more expensive) makes the two runtimes comparable.

I went with my answer for my application since my serialization/function cost ratio was very low.
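The chunksize point can be made concrete. When chunksize is None, CPython's map/starmap aim for roughly four chunks per worker, rounding up (an implementation detail of multiprocessing/pool.py, not documented API). For 10000 items and 4 workers that is 625 items per task, so starmap ships 16 tasks while imap with chunksize=1 ships 10000. A sketch that reuses that heuristic for imap so the *_star approach gets the same batching; default_chunksize is a hypothetical helper, and ThreadPool is used only to keep the example self-contained (threads don't serialize, so it won't reproduce the timing gap).

```python
from multiprocessing.pool import ThreadPool


def default_chunksize(n_items, n_workers):
    """Mirror CPython's map/starmap heuristic for chunksize=None.
    Clamped to at least 1 here because imap() rejects chunksize=0."""
    chunksize, extra = divmod(n_items, n_workers * 4)
    if extra:
        chunksize += 1
    return max(1, chunksize)


def my_function(arg1, arg2, arg3):
    return arg1 + arg2 + arg3


def my_function_star(args):
    return my_function(*args)


jobs = 4
args = [(i, i, i) for i in range(10000)]
chunksize = default_chunksize(len(args), jobs)  # 625 for 10000 items, 4 workers

# imap still yields one item at a time, so a tqdm bar around it ticks per item,
# but items now arrive in bursts of one chunk.
with ThreadPool(jobs) as pool:
    results = list(pool.imap(my_function_star, args, chunksize=chunksize))
```

With a real multiprocessing.Pool, the same chunksize argument should be passed inside the `if __name__ == '__main__':` block.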

answered Oct 20 '22 by cydonian