Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Polars map_batches UDF with Multi-processing

I want to apply a numba UDF that, for each group in df, returns a vector of the same length as that group:

import numba
import numpy as np
import polars as pl

# Small example frame: two groups ("A" with three rows, "B" with two rows)
# and an integer "index" column that the UDF accumulates per group.
df = pl.DataFrame(
    {
        "group": ["A", "A", "A", "B", "B"],
        "index": [1, 3, 5, 1, 4],
    }
)

@numba.jit(nopython=True)
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Walks ``array`` from the first element to the last while summing its
    values. Whenever the sum is greater than or equal to ``threshold`` the
    current position is set to 1 in the output and the sum restarts from 0.

    Returns an array created with ``np.zeros`` (default dtype) that has one
    entry per element of ``array``.
    """
    flags = np.zeros(array.shape[0])
    running_total = 0

    for pos in range(array.shape[0]):
        running_total += array[pos]
        if running_total >= threshold:
            # Threshold reached: mark this position and start a fresh sum.
            flags[pos] = 1
            running_total = 0

    return flags

# Run the UDF once per "group" window (threshold 5) and expose the flags as
# a UInt8 column named "udf".
df.with_columns(
    pl.col("index")
    .map_batches(lambda series: UDF(series.to_numpy(), 5))
    .over("group")
    .cast(pl.UInt8)
    .alias("udf")
)

This is inspired by this post, where a multi-processing approach was introduced. However, in the case above, I am applying the UDF through an `over` window function. Is there an efficient way to parallelize the executions above?

expected output:

shape: (5, 3)
┌───────┬───────┬─────┐
│ group ┆ index ┆ udf │
│ ---   ┆ ---   ┆ --- │
│ str   ┆ i64   ┆ u8  │
╞═══════╪═══════╪═════╡
│ A     ┆ 1     ┆ 0   │
│ A     ┆ 3     ┆ 0   │
│ A     ┆ 5     ┆ 1   │
│ B     ┆ 1     ┆ 0   │
│ B     ┆ 4     ┆ 1   │
└───────┴───────┴─────┘
like image 202
Kevin Li Avatar asked Oct 24 '25 15:10

Kevin Li


1 Answer

Here is an example of how you can do this with numba, using numba's parallelization features:

from numba import njit, prange


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    """Apply the resetting running-sum flagging to fixed-size groups.

    ``array`` is treated as ``array.size // n`` consecutive groups of ``n``
    elements each; elements past the last complete group are left at 0.
    Within each group the values are summed in order, and whenever the sum
    is greater than or equal to ``threshold`` that position is set to 1 and
    the sum restarts from 0. Groups are iterated with ``prange``.

    Returns a ``uint8`` array shaped like ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")
    group_count = array.size // n

    for group in prange(group_count):
        start = group * n
        running_total = 0  # every group starts from a fresh sum
        for offset in range(n):
            pos = start + offset
            running_total += array[pos]
            if running_total >= threshold:
                flags[pos] = 1
                running_total = 0

    return flags

# Groups are assumed to be 3 rows long here (n=3); the threshold is 5.
df = df.with_columns(
    pl.Series("new_udf", UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)
print(df)

Prints (for a 9-row example frame with three groups, A, B and C, of three rows each):

shape: (9, 3)
┌───────┬───────┬─────────┐
│ group ┆ index ┆ new_udf │
│ ---   ┆ ---   ┆ ---     │
│ str   ┆ i64   ┆ u8      │
╞═══════╪═══════╪═════════╡
│ A     ┆ 1     ┆ 0       │
│ A     ┆ 3     ┆ 0       │
│ A     ┆ 5     ┆ 1       │
│ B     ┆ 1     ┆ 0       │
│ B     ┆ 4     ┆ 1       │
│ B     ┆ 8     ┆ 1       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 4     ┆ 1       │
└───────┴───────┴─────────┘

Benchmark:

from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N, n):
    """Build a benchmark frame of ``N`` rows split into groups of ``n`` rows.

    The "group" column holds the labels ``group_0``, ``group_1``, ... with
    each label repeated ``n`` times in a row; the "index" column holds
    random ``uint64`` values drawn with ``np.random.randint(1, 5)``.
    ``N`` must be a multiple of ``n``.
    """
    assert N % n == 0

    labels = []
    for group_id in range(N // n):
        labels.extend([f"group_{group_id}"] * n)

    values = np.random.randint(1, 5, size=N, dtype="uint64")
    return pl.DataFrame({"group": labels, "index": values})


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Walks ``array`` from the first element to the last while summing its
    values. Whenever the sum is greater than or equal to ``threshold`` the
    current position is set to 1 in the output and the sum restarts from 0.

    Returns an array created with ``np.zeros`` (default dtype) that has one
    entry per element of ``array``.
    """
    flags = np.zeros(array.shape[0])
    running_total = 0

    for pos in range(array.shape[0]):
        running_total += array[pos]
        if running_total >= threshold:
            # Threshold reached: mark this position and start a fresh sum.
            flags[pos] = 1
            running_total = 0

    return flags


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    """Apply the resetting running-sum flagging to fixed-size groups.

    ``array`` is treated as ``array.size // n`` consecutive groups of ``n``
    elements each; elements past the last complete group are left at 0.
    Within each group the values are summed in order, and whenever the sum
    is greater than or equal to ``threshold`` that position is set to 1 and
    the sum restarts from 0. Groups are iterated with ``prange``.

    Returns a ``uint8`` array shaped like ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")
    group_count = array.size // n

    for group in prange(group_count):
        start = group * n
        running_total = 0  # every group starts from a fresh sum
        for offset in range(n):
            pos = start + offset
            running_total += array[pos]
            if running_total >= threshold:
                flags[pos] = 1
                running_total = 0

    return flags


def get_udf_polars(df):
    """Return ``df`` with a UInt8 "udf" column computed per "group" window.

    ``UDF`` is called with a threshold of 5 through ``map_batches`` combined
    with ``over("group")``.
    """
    flagged = pl.col("index").map_batches(lambda series: UDF(series.to_numpy(), 5))
    per_group = flagged.over("group")
    return df.with_columns(per_group.cast(pl.UInt8).alias("udf"))


df = get_df(3 * 33_333, 3)  # 99_999 values (~100_000), groups of length 3

# Reference result: the polars map_batches/over version.
df = get_udf_polars(df)

# Candidate result: the numba prange version on the raw "index" array.
df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)

# Both approaches must produce the same flags before they are timed.
assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


# Both code paths were already executed once above, so each is timed on a
# single repeat call (number=1).
t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit(
    'df.with_columns(pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5)))',
    number=1,
    globals=globals(),
)

print(t1)
print(t2)

Prints on my machine (AMD 5700x):

2.7000599699968006
0.00025866299984045327

With 100_000_000 rows and groups of length 3, the parallel version takes 0.06319052699836902 s (with parallel=False it takes 0.2159650030080229 s).


EDIT: Handling variable-length groups:

@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    """Apply the resetting running-sum flagging to variable-length groups.

    Group ``i`` covers positions ``indices[i]`` up to (but excluding)
    ``indices[i] + amount[i]`` of ``array``. Within each group the values
    are summed in order, and whenever the sum is greater than or equal to
    ``threshold`` that position is set to 1 and the sum restarts from 0.
    Groups are iterated with ``prange``.

    Returns a ``uint8`` array shaped like ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")

    for group in prange(indices.size):
        start = indices[group]
        stop = start + amount[group]
        running_total = 0  # every group starts from a fresh sum
        for pos in range(start, stop):
            running_total += array[pos]
            if running_total >= threshold:
                flags[pos] = 1
                running_total = 0

    return flags

def get_udf_polars_nb(df):
    """Return ``df`` with a "new_udf" column computed by ``UDF_nb_parallel_2``.

    The start position and length of every group are derived from the
    "group" column and passed, together with the "index" values and a
    threshold of 5, to the numba kernel.

    The rows of each group must be stored contiguously in ``df``; the groups
    themselves may appear in any label order.
    """
    labels = df["group"].to_numpy()
    # np.unique reports first-occurrence positions in *sorted label* order,
    # not in row order (e.g. "group_10" sorts before "group_2"). Sort the
    # positions so that consecutive entries delimit consecutive row ranges;
    # otherwise the lengths computed below can be wrong or even negative.
    indices = np.sort(np.unique(labels, return_index=True)[1])
    # Length of each group = distance to the next group's start (or to the
    # end of the column for the last group).
    amount = np.diff(np.r_[indices, [labels.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )

# Add the "new_udf" column using the variable-length-group numba kernel.
df = get_udf_polars_nb(df)

Benchmark:

import random
from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N):
    """Build a benchmark frame of ``N`` rows with variable-length groups.

    Groups are labelled 1, 2, 3, ... and stored one after another. Each
    group's length is drawn with ``random.randint(3, 10)``; the final group
    is cut short as soon as ``N`` labels have been produced. The "index"
    column holds random ``uint64`` values drawn with
    ``np.random.randint(1, 5)``.
    """
    labels = []
    label = 1
    filled = False
    while not filled:
        run_length = random.randint(3, 10)
        for _ in range(run_length):
            labels.append(label)
            if len(labels) >= N:
                # Enough rows: stop in the middle of the current group.
                filled = True
                break
        label += 1

    values = np.random.randint(1, 5, size=N, dtype="uint64")
    return pl.DataFrame({"group": labels, "index": values})


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Walks ``array`` from the first element to the last while summing its
    values. Whenever the sum is greater than or equal to ``threshold`` the
    current position is set to 1 in the output and the sum restarts from 0.

    Returns an array created with ``np.zeros`` (default dtype) that has one
    entry per element of ``array``.
    """
    flags = np.zeros(array.shape[0])
    running_total = 0

    for pos in range(array.shape[0]):
        running_total += array[pos]
        if running_total >= threshold:
            # Threshold reached: mark this position and start a fresh sum.
            flags[pos] = 1
            running_total = 0

    return flags


@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    """Apply the resetting running-sum flagging to variable-length groups.

    Group ``i`` covers positions ``indices[i]`` up to (but excluding)
    ``indices[i] + amount[i]`` of ``array``. Within each group the values
    are summed in order, and whenever the sum is greater than or equal to
    ``threshold`` that position is set to 1 and the sum restarts from 0.
    Groups are iterated with ``prange``.

    Returns a ``uint8`` array shaped like ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")

    for group in prange(indices.size):
        start = indices[group]
        stop = start + amount[group]
        running_total = 0  # every group starts from a fresh sum
        for pos in range(start, stop):
            running_total += array[pos]
            if running_total >= threshold:
                flags[pos] = 1
                running_total = 0

    return flags


def get_udf_polars(df):
    """Return ``df`` with a UInt8 "udf" column computed per "group" window.

    ``UDF`` is called with a threshold of 5 through ``map_batches`` combined
    with ``over("group")``.
    """
    flagged = pl.col("index").map_batches(lambda series: UDF(series.to_numpy(), 5))
    per_group = flagged.over("group")
    return df.with_columns(per_group.cast(pl.UInt8).alias("udf"))


def get_udf_polars_nb(df):
    """Return ``df`` with a "new_udf" column computed by ``UDF_nb_parallel_2``.

    The start position and length of every group are derived from the
    "group" column and passed, together with the "index" values and a
    threshold of 5, to the numba kernel.

    The rows of each group must be stored contiguously in ``df``; the groups
    themselves may appear in any label order.
    """
    labels = df["group"].to_numpy()
    # np.unique reports first-occurrence positions in *sorted label* order,
    # not in row order. Sort the positions so that consecutive entries
    # delimit consecutive row ranges; otherwise the lengths computed below
    # can be wrong or even negative when labels are not already ascending.
    indices = np.sort(np.unique(labels, return_index=True)[1])
    # Length of each group = distance to the next group's start (or to the
    # end of the column for the last group).
    amount = np.diff(np.r_[indices, [labels.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )


df = get_df(100_000)  # 100_000 values, group lengths 3-10 (last group may be shorter)

# Compute the reference ("udf") and the numba ("new_udf") columns.
df = get_udf_polars(df)
df = get_udf_polars_nb(df)

# Both approaches must produce the same flags before they are timed.
assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


# Both code paths were already executed once above, so each is timed on a
# single repeat call (number=1).
t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit("get_udf_polars_nb(df)", number=1, globals=globals())

print(t1)
print(t2)

Prints:

1.2675148629932664
0.0024339070077985525
like image 180
Andrej Kesely Avatar answered Oct 26 '25 04:10

Andrej Kesely



Donate For Us

If you love us, you can donate to us via PayPal or buy us a coffee so we can maintain and grow! Thank you!