 

How to determine if numba's prange actually works correctly?

In another Q+A (Can I perform dynamic cumsum of rows in pandas?) I made a comment regarding the correctness of using prange with this code (from this answer):

from numba import njit, prange

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

The comment was:

I wouldn't recommend parallelizing a loop that isn't pure. In this case the running variable makes it impure. There are 4 possible outcomes: (1) numba decides that it cannot parallelize it and just processes the loop as if it was cumsum instead of prange, (2) it can lift the variable outside the loop and use parallelization on the remainder, (3) numba incorrectly inserts synchronization between the parallel executions and the result may be bogus, (4) numba inserts the necessary synchronizations around running which may impose more overhead than you gain by parallelizing it in the first place.

And the later addition:

Of course both the running and cumsum variables make the loop "impure", not just the running variable as stated in the previous comment.

Then I was asked:

This might sound like a silly question, but how can I figure out which of the 4 things it did and improve it? I would really like to become better with numba!

Given that it could be useful for future readers I decided to create a self-answered Q+A here. Spoiler: I cannot really answer the question which of the 4 outcomes is produced (or if numba produces a totally different outcome) so I highly encourage other answers.

asked Feb 07 '19 by MSeifert

1 Answer

TL;DR: First of all: prange is identical to range, except when you add parallel to the jit, for example njit(parallel=True). If you try that, you'll see an exception about an "unsupported reduction" - that's because Numba limits the scope of prange to "pure" loops and "impure" loops with numba-supported reductions, and it puts the responsibility of making sure the loop falls into one of these categories on the user.

This is clearly stated in the documentation of Numba's prange (version 0.42):

1.10.2. Explicit Parallel Loops

Another feature of this code transformation pass is support for explicit parallel loops. One can use Numba’s prange instead of range to specify that a loop can be parallelized. The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.

What the comments refer to as "impure" is called a "cross iteration dependency" in that documentation: a variable that changes from one iteration to the next. A simple example would be:

def func(n):
    a = 0
    for i in range(n):
        a += 1
    return a

Here the variable a depends on the value it had before the loop started and how many iterations of the loop had been executed. That's what is meant by a "cross iteration dependency" or an "impure" loop.

The problem with explicitly parallelizing such a loop is that the iterations are performed in parallel, but each iteration needs to know what the other iterations are doing. Failing to do so would produce a wrong result.

Let's assume for a moment that prange spawns 4 workers and we pass 4 as n to the function. What would a completely naive implementation do?

Worker 1 starts, gets i = 0 from `prange`, and reads a = 0
Worker 2 starts, gets i = 1 from `prange`, and reads a = 0
Worker 3 starts, gets i = 2 from `prange`, and reads a = 0
Worker 1 executes the loop body and sets `a = a + 1` (=> 1)
Worker 3 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 starts, gets i = 3 from `prange`, and reads a = 1
Worker 2 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 executes the loop body and sets `a = a + 1` (=> 2)

=> Loop ended, the function returns 2 instead of the correct result 4

The order in which the different workers read, execute, and write to a can be arbitrary; this was just one example. It could also (by accident) produce the correct result! That's generally called a race condition.
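To make the walk-through concrete, here is a small sketch (plain Python, not anything numba actually does) that deterministically replays that read/write schedule on a shared variable. Workers are 0-indexed here; the schedule mirrors the interleaving above:

```python
def replay(schedule, n_workers):
    """Replay a fixed read/write interleaving of `a = a + 1` per worker."""
    a = 0                        # the shared variable
    local = [None] * n_workers   # each worker's privately read value
    for action, worker in schedule:
        if action == "read":
            local[worker] = a            # worker reads the shared value
        else:
            a = local[worker] + 1        # worker writes back its stale value + 1
    return a

# The interleaving from the walk-through: three workers read before
# anyone writes, so their increments overwrite each other.
schedule = [
    ("read", 0), ("read", 1), ("read", 2),
    ("write", 0), ("write", 2),
    ("read", 3),
    ("write", 1), ("write", 3),
]

print(replay(schedule, 4))  # -> 2, although 4 increments were executed
```

A different schedule (each worker reading only after the previous one wrote) would return the correct 4, which is exactly why race conditions are so hard to detect: the outcome depends on timing.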

What would a more sophisticated prange do that recognizes that there is such a cross iteration dependency?

There are three options:

  • Simply don't parallelize it.
  • Implement a mechanism where the workers share the variable. Typical examples here are Locks (this can incur a high overhead).
  • Recognize that it's a reduction that can be parallelized.

Given my understanding of the numba documentation (repeated again):

The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.

Numba does:

  • If it's a known reduction then use patterns to parallelize it
  • If it's not a known reduction throw an exception

Unfortunately it's not clear what "supported reductions" are. But the documentation hints that it's binary operators that operate on the previous value in the loop body:

A reduction is inferred automatically if a variable is updated by a binary function/operator using its previous value in the loop body. The initial value of the reduction is inferred automatically for += and *= operators. For other functions/operators, the reduction variable should hold the identity value right before entering the prange loop. Reductions in this manner are supported for scalars and for arrays of arbitrary dimensions.
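To illustrate what a parallelized reduction looks like conceptually, here is a sketch in plain Python (this is not numba's actual implementation): each worker reduces its own chunk into a private partial result, and the partials are combined at the end. This only works because + is associative and 0 is its identity value:

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_sum(seq, n_workers=4):
    # Split the sequence into contiguous chunks, one per worker,
    # so no worker ever touches shared state during the loop.
    size = max(1, -(-len(seq) // n_workers))  # ceiling division
    chunks = [seq[i:i + size] for i in range(0, len(seq), size)]
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        partials = list(pool.map(sum, chunks))  # private per-worker reductions
    # Combining the partials in any order is safe because + is associative.
    return sum(partials)

print(parallel_sum(list(range(101))))  # -> 5050, same as sum(range(101))
```

The key point: this transformation is only valid when the per-chunk results can be combined independently of the order in which the chunks were processed.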

The code in the OP uses a list as cross iteration dependency and calls list.append in the loop body. Personally I wouldn't call list.append a reduction, and it's not using a binary operator, so my assumption would be that it's very likely not supported. As for the other cross iteration dependency, running: it's using addition on the result of the previous iteration (which would be fine), but it also conditionally resets it to zero if it exceeds a threshold (which is probably not fine).
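One way to see why the conditional reset disqualifies running as a reduction (again a sketch in plain Python, not what numba does): a parallel reduction processes chunks independently and combines the partial results, but for this loop the chunked computation gives a different answer than the sequential one, because running carries state across the chunk boundary:

```python
def running_resets(seq, max_value):
    # Stripped-down version of the loop-carried logic from the question:
    # count how often `running` exceeds max_value and is reset.
    running, resets = 0, 0
    for x in seq:
        if running > max_value:
            resets += 1
            running = 0
        running += x
    return resets

seq = [4, 4, 4, 4, 4]
whole = running_resets(seq, 5)                                    # sequential
halves = running_resets(seq[:2], 5) + running_resets(seq[2:], 5)  # naive "parallel"
print(whole, halves)  # -> 2 1: the chunked version loses a reset
```

The second chunk restarts with running = 0 instead of the 8 carried over from the first chunk, so one reset is missed. No way of combining the per-chunk counts can recover that lost state, which is exactly what makes this a cross iteration dependency rather than a reduction.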

Numba provides ways to inspect the intermediate representations (type annotations, LLVM IR, and assembly):

dynamic_cumsum.inspect_types()
dynamic_cumsum.inspect_llvm()
dynamic_cumsum.inspect_asm()

But even if I had the required understanding of these outputs to make any statement about the correctness of the emitted code: in general it's highly nontrivial to "prove" that multi-threaded/multi-process code works correctly. Given that I lack the LLVM and assembly knowledge to even see whether it tries to parallelize the loop, I cannot actually answer your specific question of which outcome is produced.

Back to the code: as mentioned, it throws an exception (unsupported reduction) if I use parallel=True, so I assume that numba doesn't parallelize anything in the example:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

dynamic_cumsum(np.ones(100), np.arange(100), 10)
AssertionError: Invalid reduction format

During handling of the above exception, another exception occurred:

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Invalid reduction format

File "<>", line 7:
def dynamic_cumsum(seq, index, max_value):
    <source elided>
    running = 0
    for i in prange(len(seq)):
    ^

[1] During: lowering "id=2[LoopNest(index_variable = parfor_index.192, range = (0, seq_size0.189, 1))]{56: <ir.Block at <> (10)>, 24: <ir.Block at <> (7)>, 34: <ir.Block at <> (8)>}Var(parfor_index.192, <> (7))" at <> (7)

So what is left to say: prange provides no speed advantage over a normal range in this case (because it's not executing in parallel). So I would not "risk" potential problems or confuse readers here, given that it's not supported according to the numba documentation:

from numba import njit, prange

@njit
def p_dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in range(len(seq)):  # <-- here is the only change
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i] 
    cumsum.append([index[-1], running])

    return cumsum

Just a quick timing that supports the "not faster than" statement I made earlier:

import numpy as np
seq = np.random.randint(0, 100, 10_000_000)
index = np.arange(10_000_000)
max_ = 500
# Correctness and warm-up
assert p_dynamic_cumsum(seq, index, max_) == dynamic_cumsum(seq, index, max_)
%timeit p_dynamic_cumsum(seq, index, max_)
# 468 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dynamic_cumsum(seq, index, max_)
# 470 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
answered Nov 14 '22 by MSeifert