
numba doesn't parallelize range

Tags:

python

jit

numba

I have loops in my code that I want to parallelize:

from numba import njit, prange
from time import time


@njit
def f1(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit
def f2(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit(parallel=True)
def f3(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit(parallel=True)
def f4(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


for f in [f1, f2, f3, f4]:
    d = time()
    f(2500)
    print('%.02f' % (time() - d))

I get the times:

27.44
27.34
26.83
13.05

I checked the activity of my processor, and while the first three functions were at 100%, the fourth was at ~300%.

I don't understand why specifying parallel=True on its own didn't change anything, and why prange is also needed. The documentation has an example that uses range.

Labo asked Nov 08 '22 07:11


1 Answer

From the Numba documentation:

The experimental parallel=True option to @jit will attempt to optimize array operations and run them in parallel. It also adds support for prange() to explicitly parallelize a loop.

Since your function does not perform any array operations, there is nothing Numba can parallelize automatically; the loops have to be marked explicitly with prange.

To be clear: Numba only splits a loop across threads when both conditions hold: parallel=True is set in the decorator, and the loop is explicitly marked by changing range to prange.

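In f4() you put prange on every loop. I would recommend putting prange on the outermost loop only. According to the Numba documentation, when prange loops are nested, only the outermost one runs in parallel and the inner ones are treated as plain range loops, so marking them adds nothing and obscures which loop is actually parallel. I.e.: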

@njit(parallel=True)
def f5(n):
    s = 0
    for i in prange(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s

Erik Kjellgren answered Nov 15 '22 13:11