
Fast alternative for numpy.median.reduceat

Relating to this answer, is there a fast way to compute medians over an array that has groups with an unequal number of elements?

E.g.:

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67, ... ]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3,    ... ]

And then I want to compute the difference between the number and the median per group (e.g. median of group 0 is 1.025 so the first result is 1.00 - 1.025 = -0.025). So for the array above, the results would appear as:

result = [-0.025, 0.025, 0.05, -0.05, -0.19, 0.29, 0.00, 0.10, -0.10, ...]

Since np.median.reduceat doesn't exist (yet), is there another fast way to achieve this? My array will contain millions of rows so speed is crucial!

Indices can be assumed to be contiguous and ordered (it's easy to transform them if they aren't).
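A minimal sketch of that transformation, assuming integer group labels in arbitrary order: a stable argsort makes each group contiguous, and assigning through the same permutation puts per-row results back in the original order. The names `order`, `result_sorted` and `restored` are illustrative only, and the doubling step is just a placeholder for the real computation.

```python
import numpy as np

data = np.array([1.30, 1.00, 1.54, 1.05, 1.20, 1.33, 1.06])
index = np.array([1, 0, 1, 0, 1, 2, 1])

# A stable sort keeps the original relative order inside each group
order = np.argsort(index, kind="stable")
data_sorted = data[order]
index_sorted = index[order]

# Placeholder for any per-row computation done on the sorted arrays
result_sorted = data_sorted * 2

# Scatter the sorted results back to the original row positions
restored = np.empty_like(result_sorted)
restored[order] = result_sorted
```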


Example data for performance comparisons:

import numpy as np

np.random.seed(0)
rows = 10000
cols = 500
ngroup = 100

# Create random data and groups (unique per column)
data = np.random.rand(rows,cols)
groups = np.random.randint(ngroup, size=(rows,cols)) + 10*np.tile(np.arange(cols),(rows,1))

# Flatten
data = data.ravel()
groups = groups.ravel()

# Sort by group
idx_sort = groups.argsort()
data = data[idx_sort]
groups = groups[idx_sort]
asked Nov 10 '19 11:11 by Jean-Paul




1 Answer

Sometimes you need to write non-idiomatic numpy code to get a speedup that native numpy can't give you.

numba compiles your Python code to optimized machine code (via LLVM). Since much of numpy is itself already about as fast as C, numba is mostly useful when your problem doesn't lend itself to native vectorization with numpy. This is one such example (I assumed that the indices are contiguous and sorted, which is also reflected in the example data):

import numpy as np
import numba

# use the inflated example of roganjosh https://stackoverflow.com/a/58788534
data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3] 

data = np.array(data * 500) # using arrays is important for numba!
index = np.sort(np.random.randint(0, 30, 4500))               

# jit-decorate; original is available as .py_func attribute
@numba.njit('f8[:](f8[:], i8[:])') # explicit signature triggers eager compilation at definition time
def diffmedian_jit(data, index): 
    res = np.empty_like(data) 
    i_start = 0 
    for i in range(1, index.size): 
        if index[i] == index[i_start]: 
            continue 

        # here: i is the first _next_ index 
        inds = slice(i_start, i)  # i_start:i slice 
        res[inds] = data[inds] - np.median(data[inds]) 

        i_start = i 

    # also fix last label 
    res[i_start:] = data[i_start:] - np.median(data[i_start:])

    return res
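To check a compiled version against a trusted result, a plain NumPy reference can help. The following is only a sketch and not part of the original answer: `diffmedian_reference` is a hypothetical name, and it assumes, like the function above, that `index` is sorted so that each group forms one contiguous run.

```python
import numpy as np

def diffmedian_reference(data, index):
    """Subtract each group's median from its members (index must be sorted)."""
    data = np.asarray(data, dtype=np.float64)
    index = np.asarray(index)
    # Positions where a new group starts, framed by 0 and len(data)
    starts = np.flatnonzero(index[1:] != index[:-1]) + 1
    bounds = np.concatenate(([0], starts, [data.size]))
    res = np.empty_like(data)
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        res[lo:hi] = data[lo:hi] - np.median(data[lo:hi])
    return res

data = [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0, 0, 1, 1, 1, 1, 2, 3, 3]
print(np.round(diffmedian_reference(data, index), 3))
```

On the small example from the question this agrees with the expected result to three decimals. It still loops in Python once per group, so it is meant for verification rather than speed.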

And here are some timings using IPython's %timeit magic:

>>> %timeit diffmedian_jit.py_func(data, index)  # non-jitted function
... %timeit diffmedian_jit(data, index)  # jitted function
...
4.27 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
65.2 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

With the updated example data from the question, these numbers (i.e. the runtime of the Python function vs. the runtime of the JIT-accelerated function) are

>>> %timeit diffmedian_jit.py_func(data, groups) 
... %timeit diffmedian_jit(data, groups)
2.45 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93.6 ms ± 518 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

This amounts to a 65x speedup in the smaller case and a 26x speedup in the larger case (compared to slow loopy code, of course) using the accelerated code. Another upside is that (unlike typical vectorization with native numpy) we didn't need additional memory to achieve this speed: the gain comes entirely from optimized, compiled low-level code.


The above function assumes that numpy int arrays are int64 by default, which is not actually the case on Windows. So an alternative is to remove the signature from the call to numba.njit, triggering proper just-in-time compilation. But this means that the function will be compiled during the first execution, which can distort timing results (we can either execute the function once manually, using representative data types, or just accept that the first timed execution will be much slower and should be ignored). This is exactly what I tried to prevent by specifying a signature, which triggers eager compilation when the function is defined.
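Another option, sketched here as an assumption rather than something the answer itself does, is to keep the explicit signature and cast the inputs so that they match it on every platform. The names `index64` and `data64` are illustrative.

```python
import numpy as np

index = np.sort(np.random.randint(0, 30, 4500))
print(index.dtype)  # platform dependent: may be int32 or int64

# Cast explicitly so the arrays match 'f8[:](f8[:], i8[:])'
index64 = index.astype(np.int64)
data64 = np.asarray([1.00, 1.05, 1.30] * 1500, dtype=np.float64)
```

The cast copies the index array once, which is usually negligible next to the median computation.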

Anyway, in the proper JIT case the decorator we need is just

@numba.njit
def diffmedian_jit(...):

Note that the timings I showed above for the jit-compiled function only apply once the function has been compiled. This happens either at definition (with eager compilation, when an explicit signature is passed to numba.njit) or during the first function call (with lazy compilation, when no signature is passed to numba.njit). If the function is only going to be executed once, then the compile time should also count towards the speed of this method. It's typically only worth compiling functions if the total time of compilation + execution is less than the uncompiled runtime (which is actually true in the above case, where the native Python function is very slow). This is mostly the case when you call your compiled function many times.
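The effect of lazy compilation can be seen by timing the first call separately from a later one. This is a rough sketch, not a benchmark: `diffmedian_lazy` is a hypothetical name for a signature-free variant of the function above, and the `ImportError` fallback is an assumption added only so that the snippet still runs where numba is not installed.

```python
import time
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without numba, run the plain Python function
    def njit(func):
        return func

@njit  # no signature: compiled lazily on the first call
def diffmedian_lazy(data, index):
    res = np.empty_like(data)
    i_start = 0
    for i in range(1, index.size):
        if index[i] != index[i_start]:
            # i is the first element of the next group
            res[i_start:i] = data[i_start:i] - np.median(data[i_start:i])
            i_start = i
    # handle the last group
    res[i_start:] = data[i_start:] - np.median(data[i_start:])
    return res

data = np.random.rand(4500)
index = np.sort(np.random.randint(0, 30, 4500))

t0 = time.perf_counter()
first = diffmedian_lazy(data, index)   # includes compilation when numba is present
t1 = time.perf_counter()
second = diffmedian_lazy(data, index)  # reuses the already compiled code
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```

With numba installed, the first call should be noticeably slower than the second, which is why a warm-up call is worth doing before using %timeit.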

As max9111 noted in a comment, one important feature of numba is the cache keyword to jit. Passing cache=True to numba.jit will store the compiled function to disk, so that during the next execution of the given python module the function will be loaded from there rather than recompiled, which again can spare you runtime in the long run.
