Compare polars list to python list

Say I have this:

import polars

df = polars.DataFrame(dict(
  j=[1,2,3],
  k=[4,5,6],
  l=[7,8,9],
  ))
shape: (3, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 4   ┆ 7   │
│ 2   ┆ 5   ┆ 8   │
│ 3   ┆ 6   ┆ 9   │
└─────┴─────┴─────┘

I can filter for a particular row by testing one column at a time, i.e.:

df = df.filter(
  (polars.col('j') == 2) &
  (polars.col('k') == 5) &
  (polars.col('l') == 8)
  )
shape: (1, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 2   ┆ 5   ┆ 8   │
└─────┴─────┴─────┘

I'd like to compare against a list instead, though (to avoid listing each column and to accommodate DataFrames with a variable number of columns), e.g. something like:

df = df.filter(
    polars.concat_list(polars.all()) == [2, 5, 8]
    )

# exceptions.ArrowErrorException: NotYetImplemented("Casting from Int64 to LargeList(Field { name: \"item\", data_type: Int64, is_nullable: true, metadata: {} }) not supported")

Any ideas why the above is throwing the exception?

I can build the expression manually:

import functools

df = df.filter(
  functools.reduce(lambda a, e: a & e, (polars.col(c) == v for c, v in zip(df.columns, [2, 5, 8])))
  )

but I was hoping there's a way to compare lists directly - e.g. as if I had this DataFrame originally:

df = polars.DataFrame(dict(j=[
  [1,4,7],
  [2,5,8],
  [3,6,9],
  ]))
shape: (3, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [1, 4, 7] │
│ [2, 5, 8] │
│ [3, 6, 9] │
└───────────┘

and wanted to find the row which matches [2, 5, 8]. Any hints?

levant pied asked Mar 26 '26 15:03


1 Answer

You can pass multiple conditions to pl.all_horizontal() instead of using functools.reduce. (The examples below assume import polars as pl.)

For a list column, you can compare the values at each index with .list.get():

df.filter(
   pl.all_horizontal(
      pl.col("j").list.get(n) == row[n]
      for row in [[2, 5, 8]]
      for n in range(len(row))
   )
)
shape: (1, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [2, 5, 8] │
└───────────┘

I'm not sure why this returns no rows:

df.filter(pl.col("j") == pl.lit([[2, 5, 8]]))
shape: (0, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
└───────────┘

For regular columns, you could modify your example:

df.filter(
   pl.all_horizontal(
      pl.col(col) == value
      for col, value in zip(df.columns, [2, 5, 8])
   )
)
jqurious answered Mar 28 '26 06:03


