Sample from each group in polars dataframe?

I'm looking for a function along the lines of

df.group_by('column').agg(sample(10))

so that I can take ten or so randomly-selected elements from each group.

This is specifically so I can read in a LazyFrame and work with a small sample of each group as opposed to the entire dataframe.

Update:

One approximate solution is:

df = lf.group_by('column').agg(
        pl.all().sample(fraction=0.001)
    )
df = df.explode(df.columns[1:])

Update 2

That approximate solution is just the same as sampling the whole dataframe and doing a groupby after. No good.


2 Answers

Let's start with some dummy data:

import polars as pl

n = 100
seed = 0

df = pl.DataFrame({
    "groups": (pl.int_range(n, eager=True) % 5).shuffle(seed=seed),
    "values": pl.int_range(n, eager=True).shuffle(seed=seed)
})
shape: (100, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 55     │
│ 0      ┆ 40     │
│ 2      ┆ 57     │
│ 4      ┆ 99     │
│ 4      ┆ 4      │
│ …      ┆ …      │
│ 0      ┆ 90     │
│ 2      ┆ 87     │
│ 1      ┆ 96     │
│ 3      ┆ 43     │
│ 4      ┆ 44     │
└────────┴────────┘

This gives us 100 / 5 = 5 groups of 20 elements each. Let's verify that:

df.group_by("groups").agg(pl.len())
shape: (5, 2)
┌────────┬─────┐
│ groups ┆ len │
│ ---    ┆ --- │
│ i64    ┆ u32 │
╞════════╪═════╡
│ 0      ┆ 20  │
│ 4      ┆ 20  │
│ 2      ┆ 20  │
│ 3      ┆ 20  │
│ 1      ┆ 20  │
└────────┴─────┘

Sample our data

Now we are going to use a window function to take a sample of our data.

df.filter(
    pl.int_range(pl.len()).shuffle().over("groups") < 10
)
shape: (50, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 55     │
│ 2      ┆ 57     │
│ 4      ┆ 99     │
│ 4      ┆ 4      │
│ 1      ┆ 81     │
│ …      ┆ …      │
│ 2      ┆ 22     │
│ 1      ┆ 76     │
│ 3      ┆ 98     │
│ 0      ┆ 90     │
│ 4      ┆ 44     │
└────────┴────────┘

For every group in over("groups"), the pl.int_range(pl.len()) expression creates an index column. We then shuffle that range so that we take a sample and not a slice. Finally we keep only the index values lower than 10. This produces a boolean mask that we can pass to the filter method.

answered Jan 28 '26 04:01 by ritchie46

This worked better for me:

sampled_df = pl.concat(
    part.sample(fraction=0.001) for part in
    df.partition_by(["column"], include_key=True)
)

The problem with .agg(pl.col("column").sample(2)) was that it seemed to select different values for each column. What I needed was randomly selected whole rows.

answered Jan 28 '26 03:01 by santon
