
Efficient way to get several subsets of list elements?

I have a DataFrame like this:

import polars as pl

# increase list repr defaults
pl.Config(fmt_table_cell_list_len=10, fmt_str_lengths=100)

df = pl.DataFrame(
    {
        "grp": ["a", "b"],
        "val": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
    }
)
df
shape: (2, 2)
┌─────┬─────────────────────────────────┐
│ grp ┆ val                             │
│ --- ┆ ---                             │
│ str ┆ list[i64]                       │
╞═════╪═════════════════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]                 │
│ b   ┆ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] │
└─────┴─────────────────────────────────┘

I want to select elements in the val column based on this pattern:

  • take the first 2 values
  • take the last 2 values
  • take a random sample of 2 values from the "middle" (the values that remain after excluding the first 2 and last 2)

Combine those three selections and keep only the unique elements.

This means that when a list has 6 or fewer values (as in the first row), all of its values are returned; otherwise (as in the second row), only a subset of 6 values is returned.

Therefore, the desired output would look like this:

shape: (2, 2)
┌─────┬─────────────────────┐
│ grp ┆ val                 │
│ --- ┆ ---                 │
│ str ┆ list[i64]           │
╞═════╪═════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]     │ 
│ b   ┆ [1, 2, 4, 7, 9, 10] │  # <<<< 4 and 7 are the two randomly selected values in the "middle" set
└─────┴─────────────────────┘
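To pin down the intended semantics, here is a plain-Python reference sketch that works on ordinary lists rather than on a Polars expression. The function name select_subset and its k and seed parameters are hypothetical; the sketch assumes "unique" means de-duplicating by value while keeping first-seen order, and it keeps the sampled middle values in their original order, which matches the look of the desired output above.

```python
import random


def select_subset(values, k=2, seed=None):
    """Hypothetical reference: first k, last k, and up to k random middle values.

    Sampling is without replacement and is capped at the size of the middle,
    so short lists never raise. Duplicates are removed by value, keeping the
    first occurrence.
    """
    rng = random.Random(seed)
    head = values[:k]
    tail = values[-k:] if k > 0 else []
    middle = values[k:len(values) - k]  # empty when len(values) <= 2 * k
    n = min(k, len(middle))
    # Sample positions, then sort them so the picked values keep list order.
    picked = [middle[i] for i in sorted(rng.sample(range(len(middle)), n))]
    # dict.fromkeys keeps insertion order while dropping repeated values.
    return list(dict.fromkeys(head + picked + tail))


print(select_subset([1, 2, 3, 4, 5], seed=1234))  # [1, 2, 3, 4, 5]
print(select_subset(list(range(1, 11)), seed=1234))
```

Capping n at the size of the middle is exactly the behaviour the fixed-size list.sample(2) call below lacks.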

To select the first two and last two values, I can use list.head() and list.tail(). For the random pick among the remaining values, I thought I could use list.set_difference() to remove the first and last two values and then call list.sample(). However, list.sample() fails: in the first row only one value is left after removing the first and last two, and I ask for two:

(
    df.select(
        head=pl.col("val").list.head(2),

        middle=pl.col("val")
        .list.set_difference(pl.col("val").list.head(2))
        .list.set_difference(pl.col("val").list.tail(2))
        .list.sample(2, seed=1234),
        
        tail=pl.col("val").list.tail(2),
    ).select(concat=pl.concat_list(["head", "middle", "tail"]).list.unique())
)
ShapeError: cannot take a larger sample than the total population when `with_replacement=false`

and I don't want a sample with replacement.

What would be the best way to do this with Polars?
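A side note on the attempt above: list.set_difference compares values, not positions, so if a middle value also appears among the first or last two values, it is dropped from the middle as well. Selecting the middle by position avoids this. A plain-Python sketch contrasting the two (the helper names and the example row are made up for illustration; Python sets stand in for the Polars set operation, and the result is sorted only to make the output deterministic):

```python
def middle_by_value(values, k=2):
    # Like removing head and tail with a set difference: drops matching VALUES.
    excluded = set(values[:k]) | set(values[-k:])
    return sorted(set(values) - excluded)


def middle_by_position(values, k=2):
    # Keeps everything between the first k and the last k POSITIONS.
    return values[k:len(values) - k]


row = [1, 2, 1, 5, 6, 9, 10]
print(middle_by_value(row))     # [5, 6]
print(middle_by_position(row))  # [1, 5, 6]
```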

asked Dec 18 '25 at 15:12 by bretauv


1 Answer

# increase repr defaults
pl.Config(fmt_table_cell_list_len=12, fmt_str_lengths=100)

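You could slice out the middle values, use list.sample(fraction=1, shuffle=True) to simply shuffle them, and then take the head of the result. Unlike sample(2), this does not fail when fewer than 2 middle values remain; it just returns fewer values.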

df.select(
    pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
      .list.sample(fraction=1, shuffle=True)
      .list.head(2)
)
shape: (2, 1)
┌───────────┐
│ val       │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [3]       │
│ [7, 4]    │
└───────────┘
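The two steps used here can be modelled on plain Python lists. For lists with at least two values, head(len - 2) followed by slice(2) keeps positions 2 through len - 3, the same elements as v[2:-2]; shuffling all of them and keeping the first two yields at most two values without replacement. By contrast, random.sample(middle, 2) raises ValueError when fewer than two values remain, which is the plain-Python analogue of the ShapeError in the question. The helper names are hypothetical and this is not how Polars implements list.sample.

```python
import random


def middle_values(values):
    # Same positions as head(len - 2) followed by slice(2): indices 2 .. len - 3.
    return values[:len(values) - 2][2:]


def shuffle_then_head(values, k, seed=None):
    # Shuffle a copy of every value, then keep at most k of them.
    # This never raises when len(values) < k; it just returns fewer values.
    shuffled = list(values)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:k]


for row in ([1, 2, 3, 4, 5], list(range(1, 11))):
    mid = middle_values(row)
    print(mid, shuffle_then_head(mid, 2, seed=0))

# A fixed-size sample without replacement fails on the short middle:
try:
    random.sample([3], 2)
except ValueError as exc:
    print("ValueError:", exc)
```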

when/then/otherwise can then choose between the original list (for lists with 5 or fewer values) and the combination of head, shuffled middle sample, and tail.

df.with_columns(
    pl.when(pl.col("val").list.len() > 5)
      .then(
          pl.concat_list(
              pl.col("val").list.head(2),
              pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
                .list.sample(fraction=1, shuffle=True)
                .list.head(2),
              pl.col("val").list.tail(2),
          )
      )
      .otherwise("val")
      .alias("sample")
)
shape: (2, 3)
┌─────┬─────────────────────────────────┬─────────────────────┐
│ grp ┆ val                             ┆ sample              │
│ --- ┆ ---                             ┆ ---                 │
│ str ┆ list[i64]                       ┆ list[i64]           │
╞═════╪═════════════════════════════════╪═════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]                 ┆ [1, 2, 3, 4, 5]     │
│ b   ┆ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ┆ [1, 2, 6, 5, 9, 10] │
└─────┴─────────────────────────────────┴─────────────────────┘

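Alternatively, list.sample accepts an expression for n, so you could sample 0 items when the list is not large enough. That inner expression is what avoids the ShapeError, because Polars evaluates the then branch for every row, even rows that end up using the otherwise value; the outer when/then is still what returns short lists unchanged.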

df.with_columns(
    pl.when(pl.col("val").list.len() > 5)
      .then(
          pl.concat_list(
              pl.col("val").list.head(2),
              pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
                .list.sample(n=pl.when(pl.col("val").list.len() > 5).then(2).otherwise(0)),
              pl.col("val").list.tail(2)
          )
      )
      .otherwise(pl.col("val"))
      .list.unique()
      .alias("sample")
)
shape: (2, 3)
┌─────┬─────────────────────────────────┬─────────────────────┐
│ grp ┆ val                             ┆ sample              │
│ --- ┆ ---                             ┆ ---                 │
│ str ┆ list[i64]                       ┆ list[i64]           │
╞═════╪═════════════════════════════════╪═════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]                 ┆ [1, 2, 3, 4, 5]     │
│ b   ┆ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ┆ [1, 2, 5, 7, 9, 10] │
└─────┴─────────────────────────────────┴─────────────────────┘
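Why n has to shrink to 0 even inside the outer when can be seen in a simplified plain-Python model of the evaluation order: the then branch is computed for all rows first, and only afterwards is a value picked per row. The helpers below (combined, when_then_otherwise, is_long) are hypothetical illustrations, not Polars internals.

```python
import random

rng = random.Random(0)
rows = [[1, 2, 3, 4, 5], list(range(1, 11))]


def combined(values, n):
    # head 2 + n values sampled without replacement from the middle + tail 2
    return values[:2] + rng.sample(values[2:-2], n) + values[-2:]


def when_then_otherwise(rows, cond, then_fn):
    # Model: compute the `then` branch for EVERY row first,
    # then choose per row between it and the original list (`otherwise`).
    then_values = [then_fn(v) for v in rows]
    return [t if cond(v) else v for v, t in zip(rows, then_values)]


def is_long(values):
    return len(values) > 5


try:
    when_then_otherwise(rows, is_long, lambda v: combined(v, 2))
except ValueError as exc:
    print("fixed n=2 fails:", exc)

# n shrinks to 0 for short rows, so the discarded `then` value is harmless.
print(when_then_otherwise(rows, is_long, lambda v: combined(v, 2 if is_long(v) else 0)))
```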
answered Dec 21 '25 at 06:12 by jqurious


