In another Q+A (Can I perform dynamic cumsum of rows in pandas?) I commented on whether using `prange` is correct
in this code (from this answer):
from numba import njit, prange

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
The comment was:
I wouldn't recommend parallelizing a loop that isn't pure. In this case the `running` variable makes it impure. There are 4 possible outcomes: (1) numba decides that it cannot parallelize it and just processes the loop as if it was `range` instead of `prange` (2) it can lift the variable outside the loop and use parallelization on the remainder (3) numba incorrectly inserts synchronization between the parallel executions and the result may be bogus (4) numba inserts the necessary synchronizations around `running` which may impose more overhead than you gain by parallelizing it in the first place
And the later addition:
Of course both the `running` and `cumsum` variables make the loop "impure", not just the `running` variable as stated in the previous comment
Then I was asked:
This might sound like a silly question, but how can I figure out which of the 4 things it did and improve it? I would really like to become better with numba!
Given that it could be useful for future readers I decided to create a self-answered Q+A here. Spoiler: I cannot really answer the question which of the 4 outcomes is produced (or if numba produces a totally different outcome) so I highly encourage other answers.
TL;DR: First: `prange` is identical to `range`, except when you add `parallel` to the `jit`, for example `njit(parallel=True)`. If you try that you'll see an exception about an unsupported reduction ("Invalid reduction format") - that's because Numba limits the scope of `prange` to "pure" loops and "impure" loops with Numba-supported reductions, and puts the responsibility of making sure that the loop falls into either of these categories on the user.
This is clearly stated in the documentation of Numba's `prange` (version 0.42):
1.10.2. Explicit Parallel Loops
Another feature of this code transformation pass is support for explicit parallel loops. One can use Numba's `prange` instead of `range` to specify that a loop can be parallelized. The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.
What the comments refer to as "impure" is called "cross iteration dependencies" in that documentation. Such a cross-iteration dependency is a variable whose value is carried over from one loop iteration to the next. A simple example would be:
def func(n):
    a = 0
    for i in range(n):
        a += 1
    return a
Here the variable `a` depends on the value it had before the loop started and on how many iterations of the loop have been executed. That's what is meant by a "cross iteration dependency" or an "impure" loop.
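To make the distinction concrete, here is a small plain-Python sketch (no Numba involved; the function names and the explicit `order` argument are made up for illustration). The first loop only touches its own slot, so the visiting order does not matter. The second carries `total` from one iteration to the next, so changing the order changes the result.

```python
def double_each(values, order):
    # No cross-iteration dependency: iteration i reads values[i] and
    # writes out[i], nothing else.
    out = [0] * len(values)
    for i in order:
        out[i] = 2 * values[i]
    return out


def running_total(values, order):
    # Cross-iteration dependency: every iteration needs the `total`
    # left behind by the iteration that ran before it.
    total = 0
    out = [0] * len(values)
    for i in order:
        total += values[i]
        out[i] = total
    return out


values = [1, 2, 3, 4]
forward = [0, 1, 2, 3]
backward = [3, 2, 1, 0]

print(double_each(values, forward) == double_each(values, backward))      # True
print(running_total(values, forward) == running_total(values, backward))  # False
```

A parallel loop gives no guarantee about the order in which iterations run, which is why only the first kind of loop is safe to hand to it without further care.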
The problem when explicitly parallelizing such a loop is that iterations are performed in parallel but each iteration needs to know what the other iterations are doing. Without that coordination the result can be wrong.
Let's for a moment assume that `prange` would spawn 4 workers and we pass `4` as `n` to the function. What would a completely naive implementation do?
Worker 1 starts, gets i = 0 from `prange`, and reads a = 0
Worker 2 starts, gets i = 1 from `prange`, and reads a = 0
Worker 3 starts, gets i = 2 from `prange`, and reads a = 0
Worker 1 executed the loop and sets `a = a + 1` (=> 1)
Worker 3 executed the loop and sets `a = a + 1` (=> 1)
Worker 4 starts, gets i = 3 from `prange`, and reads a = 1
Worker 2 executed the loop and sets `a = a + 1` (=> 1)
Worker 4 executed the loop and sets `a = a + 1` (=> 2)
=> Loop ended, function returns 2 (the correct result would be 4)
The order in which the different workers read, execute and write to `a` can be arbitrary; this was just one example. It could also produce (by accident) the correct result! That's generally called a race condition.
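The effect of the interleaving can be reproduced deterministically without any threads. The following is only a toy model (the `replay` helper and the two schedules are invented for this illustration, and it says nothing about how Numba schedules work): each worker first reads the shared counter into a private copy and later writes that copy plus one back.

```python
def replay(schedule):
    """Replay a fixed interleaving of 'read' and 'write' steps on a shared counter."""
    shared = 0
    local = {}  # the value each worker saw when it read the counter
    for worker, step in schedule:
        if step == "read":
            local[worker] = shared
        else:  # "write": store the (possibly stale) private copy plus one
            shared = local[worker] + 1
    return shared


# Every worker finishes its read and write before the next one starts.
serial = [(w, s) for w in range(4) for s in ("read", "write")]

# All four workers read before any of them writes.
all_read_first = [(w, "read") for w in range(4)] + [(w, "write") for w in range(4)]

print(replay(serial))          # 4
print(replay(all_read_first))  # 1
```

Both schedules execute exactly the same eight steps; only their order differs, and that alone decides whether updates are lost.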
What would a more sophisticated `prange` do that recognizes that there is such a cross iteration dependency?
There are three options:
Given my understanding of the numba documentation (repeated again):
The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.
Numba does:
Unfortunately it's not clear what "supported reductions" are. But the documentation hints that it's binary operators that operate on the previous value in the loop body:
A reduction is inferred automatically if a variable is updated by a binary function/operator using its previous value in the loop body. The initial value of the reduction is inferred automatically for `+=` and `*=` operators. For other functions/operators, the reduction variable should hold the identity value right before entering the `prange` loop. Reductions in this manner are supported for scalars and for arrays of arbitrary dimensions.
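As a minimal sketch of the kind of loop this passage seems to describe, here is a scalar accumulated with `+=` inside `prange`. The function name `sum_of_squares` is made up for illustration, and the `try`/`except` fallback is my own addition so that the snippet also runs (sequentially) where Numba is not installed.

```python
try:
    from numba import njit, prange
except ImportError:
    # Assumption for this sketch only: without Numba, fall back to plain Python.
    prange = range

    def njit(*args, **kwargs):
        # Supports both the bare @njit and the @njit(parallel=True) form.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True)
def sum_of_squares(n):
    total = 0.0
    for i in prange(n):
        # The only cross-iteration dependency is `total`, and it is updated
        # exclusively through `+=`, the pattern the quoted passage names.
        total += i * i
    return total


print(sum_of_squares(10))
```

Because addition can be regrouped freely, partial sums computed by different workers can be combined afterwards, which is presumably what makes this pattern parallelizable.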
The code in the OP uses a list as cross iteration dependency and calls `list.append` in the loop body. Personally I wouldn't call `list.append` a reduction and it's not using a binary operator, so my assumption would be that it's very likely not supported. As for the other cross iteration dependency, `running`: it's using addition on the result of the previous iteration (which would be fine) but also conditionally resets it to zero if it exceeds a threshold (which is probably not fine).
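To see why the conditional reset is a problem, consider a plain-Python copy of the loop and a naive "parallel" stand-in that splits the input into two halves, processes each half with `running` starting at zero, and concatenates the results. The helper `naive_split` is hypothetical and is not what Numba does; it only shows that the work cannot simply be cut into independent pieces.

```python
def dynamic_cumsum_py(seq, index, max_value):
    # Plain-Python copy of the loop from the question.
    cumsum = []
    running = 0
    for i in range(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum


def naive_split(seq, index, max_value):
    # Pretend two workers each handle one half, unaware of each other.
    half = len(seq) // 2
    first = dynamic_cumsum_py(seq[:half], index[:half], max_value)
    second = dynamic_cumsum_py(seq[half:], index[half:], max_value)
    return first + second


seq = [3, 3, 3, 3, 3, 3]
index = [0, 1, 2, 3, 4, 5]

print(dynamic_cumsum_py(seq, index, 5))  # [[2, 6], [4, 6], [5, 6]]
print(naive_split(seq, index, 5))        # [[2, 6], [2, 3], [5, 6], [5, 3]]
```

The second worker has no way of knowing that the first one ended with a leftover `running` of 3, so every reset after the split point lands in the wrong place.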
Numba provides ways to inspect the inferred types and the generated code (LLVM and ASM):
dynamic_cumsum.inspect_types()
dynamic_cumsum.inspect_llvm()
dynamic_cumsum.inspect_asm()
But even if I had the understanding needed to judge the correctness of the emitted code, it is in general highly nontrivial to "prove" that multi-threaded or multi-process code works correctly. Since I lack the LLVM and ASM knowledge to tell whether it even tries to parallelize the loop, I cannot actually answer your specific question of which outcome it produces.
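One more avenue that may help: newer Numba releases expose a `parallel_diagnostics` method on functions compiled with `parallel=True`, which prints a report about the loops the parallel pass transformed. Whether it is available depends on the installed version, so the sketch below guards the call. The helper name `report_parallel_transforms` and the example function are hypothetical, and the report only exists for functions that compiled successfully.

```python
try:
    from numba import njit, prange
    HAVE_NUMBA = True
except ImportError:
    HAVE_NUMBA = False


def report_parallel_transforms(n=10):
    """Compile a small parallel loop and print Numba's report about it, if possible."""
    if not HAVE_NUMBA:
        return None  # nothing to inspect without Numba

    @njit(parallel=True)
    def sum_of_squares(m):
        total = 0.0
        for i in prange(m):
            total += i * i
        return total

    result = sum_of_squares(n)  # the first call triggers compilation

    # Guarded because older releases may not provide this method.
    if hasattr(sum_of_squares, "parallel_diagnostics"):
        sum_of_squares.parallel_diagnostics(level=4)
    return result


report_parallel_transforms(10)
```

Reading such a report is likely much easier than reading the LLVM or ASM output, but it still does not prove that a given loop is free of races.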
Back to the code: as mentioned, it throws an exception (unsupported reduction) if I use `parallel=True`, so I assume that numba doesn't parallelize anything in the example:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum

dynamic_cumsum(np.ones(100), np.arange(100), 10)
AssertionError: Invalid reduction format

During handling of the above exception, another exception occurred:

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Invalid reduction format

File "<>", line 7:
def dynamic_cumsum(seq, index, max_value):
    <source elided>
    running = 0
    for i in prange(len(seq)):
    ^

[1] During: lowering "id=2[LoopNest(index_variable = parfor_index.192, range = (0, seq_size0.189, 1))]{56: <ir.Block at <> (10)>, 24: <ir.Block at <> (7)>, 34: <ir.Block at <> (8)>}Var(parfor_index.192, <> (7))" at <> (7)
So what is left to say: `prange` does not provide any speed advantage in this case over a normal `range` (because it's not executing in parallel). So in that case I would not "risk" potential problems and/or confusing the readers - given that it's not supported according to the numba documentation.
from numba import njit, prange

@njit
def p_dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in range(len(seq)):  # <-- here is the only change
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
Just a quick timing that supports the "not faster than" statement I made earlier:
import numpy as np
seq = np.random.randint(0, 100, 10_000_000)
index = np.arange(10_000_000)
max_ = 500
# Correctness and warm-up
assert p_dynamic_cumsum(seq, index, max_) == dynamic_cumsum(seq, index, max_)
%timeit p_dynamic_cumsum(seq, index, max_)
# 468 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dynamic_cumsum(seq, index, max_)
# 470 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)