
What batch_size and pre_dispatch in joblib exactly mean

From the documentation (https://pythonhosted.org/joblib/parallel.html#parallel-reference-documentation) it's not clear to me what exactly batch_size and pre_dispatch mean.

Let's consider the case where we are using the 'multiprocessing' backend, 2 jobs (2 processes), and we have 10 tasks to compute.

As I understand it:

batch_size - controls the number of tasks pickled and dispatched at one time. If you set batch_size=5, joblib will pickle and send 5 tasks at once to each process, and after arriving there they will be solved by that process sequentially, one after another. With batch_size=1, joblib will pickle and send one task at a time, and only once that process has completed its previous task.

To show what I mean:

def solve_one_task(task):
    # Solves one task at a time
    ...  # actual work goes here
    return result

def solve_list(list_of_tasks):
    # Solves batch of tasks sequentially
    return [solve_one_task(task) for task in list_of_tasks]

So this code:

Parallel(n_jobs=2, backend='multiprocessing', batch_size=5)(
        delayed(solve_one_task)(task) for task in tasks)

should be roughly equal to this code in performance:

slices = [(0, 5), (5, 10)]
Parallel(n_jobs=2, backend='multiprocessing', batch_size=1)(
        delayed(solve_list)(tasks[s[0]:s[1]]) for s in slices)

Am I right? And what does pre_dispatch mean, then?
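
One way to check this is to time both variants directly. Here is a minimal sketch (assuming solve_one_task simply sleeps to stand in for real work; on Windows the Parallel calls would need an if __name__ == '__main__': guard):

from joblib import Parallel, delayed
from time import sleep, time

def solve_one_task(task):
    # Stand-in for real work
    sleep(0.5)
    return task

def solve_list(list_of_tasks):
    # Solves a batch of tasks sequentially
    return [solve_one_task(task) for task in list_of_tasks]

tasks = list(range(10))

start = time()
Parallel(n_jobs=2, backend='multiprocessing', batch_size=5)(
        delayed(solve_one_task)(task) for task in tasks)
print("batch_size=5:   %.2fs" % (time() - start))

start = time()
slices = [(0, 5), (5, 10)]
Parallel(n_jobs=2, backend='multiprocessing', batch_size=1)(
        delayed(solve_list)(tasks[s[0]:s[1]]) for s in slices)
print("manual batches: %.2fs" % (time() - start))

If the reasoning above is correct, both timings should come out nearly identical.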

asked Nov 14 '15 by Ibraim Ganiev

1 Answer

As it turns out, I was right: the two sections of code are pretty similar performance-wise, so batch_size works as I expected in the question. pre_dispatch (as the documentation states) controls the number of tasks instantiated from the iterator into the task queue ahead of the workers. The following experiment demonstrates this:

# Note: sklearn.externals.joblib has since been removed from scikit-learn;
# with standalone joblib, use: from joblib import Parallel, delayed
from sklearn.externals.joblib import Parallel, delayed
from time import sleep, time

def solve_one_task(task):
    # Solves one task at a time
    print("%d. Task #%d is being solved"%(time(), task))
    sleep(5)
    return task

def task_gen(max_task):
    current_task = 0
    while current_task < max_task:
        print("%d. Task #%d was dispatched"%(time(), current_task))
        yield current_task
        current_task += 1

Parallel(n_jobs=2, backend='multiprocessing', batch_size=1, pre_dispatch=3)(
        delayed(solve_one_task)(task) for task in task_gen(10))

outputs:

1450105367. Task #0 was dispatched
1450105367. Task #1 was dispatched
1450105367. Task #2 was dispatched
1450105367. Task #0 is being solved
1450105367. Task #1 is being solved
1450105372. Task #2 is being solved
1450105372. Task #3 was dispatched
1450105372. Task #4 was dispatched
1450105372. Task #3 is being solved
1450105377. Task #4 is being solved
1450105377. Task #5 was dispatched
1450105377. Task #5 is being solved
1450105377. Task #6 was dispatched
1450105382. Task #7 was dispatched
1450105382. Task #6 is being solved
1450105382. Task #7 is being solved
1450105382. Task #8 was dispatched
1450105387. Task #9 was dispatched
1450105387. Task #8 is being solved
1450105387. Task #9 is being solved
Out[1]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
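
The output shows the effect: three tasks are instantiated from the generator immediately, and each subsequent task is only dispatched as earlier ones are consumed by the workers, so at most three tasks sit in the queue ahead of the processes. Note that pre_dispatch also accepts string expressions evaluated against n_jobs; the default is '2 * n_jobs', which keeps the workers fed without consuming the whole generator up front:

# Default behaviour: keep 2 * n_jobs = 4 tasks pre-dispatched from the generator
Parallel(n_jobs=2, backend='multiprocessing', batch_size=1, pre_dispatch='2 * n_jobs')(
        delayed(solve_one_task)(task) for task in task_gen(10))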
answered by Ibraim Ganiev