 

Python Polars: How to apply an aggregate function to all columns and pass one additional column as an argument?

I have a lazy dataframe (created with scan_parquet) like the one below:

df = pl.from_repr("""
┌────────┬──────┬──────┬──────┬──────┐
│ region ┆ time ┆ sen1 ┆ sen2 ┆ sen3 │
│ ---    ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
│ str    ┆ i64  ┆ f64  ┆ f64  ┆ f64  │
╞════════╪══════╪══════╪══════╪══════╡
│ us     ┆ 1    ┆ 10.0 ┆ 11.0 ┆ 12.0 │
│ us     ┆ 2    ┆ 11.0 ┆ 14.0 ┆ 13.0 │
│ us     ┆ 3    ┆ 10.1 ┆ 10.0 ┆ 12.3 │
│ us     ┆ 4    ┆ 13.0 ┆ 11.1 ┆ 14.0 │
│ us     ┆ 5    ┆ 12.0 ┆ 11.0 ┆ 19.0 │
│ uk     ┆ 1    ┆ 10.0 ┆ 11.0 ┆ 12.1 │
│ uk     ┆ 2    ┆ 11.0 ┆ 14.0 ┆ 13.0 │
│ uk     ┆ 3    ┆ 10.1 ┆ 10.0 ┆ 12.0 │
│ uk     ┆ 4    ┆ 13.0 ┆ 11.1 ┆ 14.0 │
│ uk     ┆ 5    ┆ 12.0 ┆ 11.0 ┆ 19.0 │
│ uk     ┆ 6    ┆ 13.7 ┆ 11.1 ┆ 14.0 │
│ uk     ┆ 7    ┆ 12.0 ┆ 11.0 ┆ 21.9 │
└────────┴──────┴──────┴──────┴──────┘
""")

I want to find the max and min of every sensor for each region, and I also want the time at which each max and min occurred.

So I wrote the following aggregate function:

def my_custom_agg(t, v):
    smax = v.max()
    smin = v.min()
    smax_t = t[v.arg_max()]  # time at which the maximum occurs
    smin_t = t[v.arg_min()]  # time at which the minimum occurs
    return [smax, smin, smax_t, smin_t]

Then I ran the group_by as follows:

df.group_by('region').agg(
    pl.all().map_elements(lambda s: my_custom_agg(pl.col('time'),s))
)

When I do this, I get the following error:

TypeError: 'Expr' object is not subscriptable

Expected result:

region sen1              sen2              sen3
us     [13.0,10.0,4,1]   [14.0,10.0,2,3]   [19.0,12.0,5,1]
uk     [13.7,10.0,6,1]   [14.0,10.0,2,3]   [21.9,12.0,7,3]

# which I will then unpivot and transform into:
region   sname  smax  smin smax_t  smin_t
us       sen1   13.0  10.0 4       1
us       sen2   14.0  10.0 2       3
us       sen3   19.0  12.0 5       1
uk       sen1   13.7  10.0 6       1
uk       sen2   14.0  10.0 2       3
uk       sen3   21.9  12.0 7       3

Could you please tell me how to pass one additional column as an argument? If there is an alternative way to do this, I am happy to hear it since I am flexible with the output format.

Note: my real dataset has 8k sensors, so a solution that uses pl.all() (rather than listing each column) would be preferable.

Thanks for your support.

asked Sep 19 '25 13:09 by Selva


1 Answer

You can do this with native Polars expressions alone and avoid the need for map_elements.

First, reshape the data with .unpivot():

df.unpivot(index=["region", "time"], variable_name="sname")
shape: (36, 4)
┌────────┬──────┬───────┬───────┐
│ region ┆ time ┆ sname ┆ value │
│ ---    ┆ ---  ┆ ---   ┆ ---   │
│ str    ┆ i64  ┆ str   ┆ f64   │
╞════════╪══════╪═══════╪═══════╡
│ us     ┆ 1    ┆ sen1  ┆ 10.0  │
│ us     ┆ 2    ┆ sen1  ┆ 11.0  │
│ us     ┆ 3    ┆ sen1  ┆ 10.1  │
│ us     ┆ 4    ┆ sen1  ┆ 13.0  │
│ us     ┆ 5    ┆ sen1  ┆ 12.0  │
│ …      ┆ …    ┆ …     ┆ …     │
│ uk     ┆ 3    ┆ sen3  ┆ 12.0  │
│ uk     ┆ 4    ┆ sen3  ┆ 14.0  │
│ uk     ┆ 5    ┆ sen3  ┆ 19.0  │
│ uk     ┆ 6    ┆ sen3  ┆ 14.0  │
│ uk     ┆ 7    ┆ sen3  ┆ 21.9  │
└────────┴──────┴───────┴───────┘

It is then just a regular group_by / agg:

(
   df
   .unpivot(index=["region", "time"], variable_name="sname")
   .group_by("region", "sname")
   .agg(
       pl.all().get(pl.col("value").arg_min()).name.suffix("_min"),
       pl.all().get(pl.col("value").arg_max()).name.suffix("_max")
   )
)
shape: (6, 6)
┌────────┬───────┬──────────┬───────────┬──────────┬───────────┐
│ region ┆ sname ┆ time_min ┆ value_min ┆ time_max ┆ value_max │
│ ---    ┆ ---   ┆ ---      ┆ ---       ┆ ---      ┆ ---       │
│ str    ┆ str   ┆ i64      ┆ f64       ┆ i64      ┆ f64       │
╞════════╪═══════╪══════════╪═══════════╪══════════╪═══════════╡
│ uk     ┆ sen1  ┆ 1        ┆ 10.0      ┆ 6        ┆ 13.7      │
│ us     ┆ sen1  ┆ 1        ┆ 10.0      ┆ 4        ┆ 13.0      │
│ uk     ┆ sen3  ┆ 3        ┆ 12.0      ┆ 7        ┆ 21.9      │
│ us     ┆ sen2  ┆ 3        ┆ 10.0      ┆ 2        ┆ 14.0      │
│ us     ┆ sen3  ┆ 1        ┆ 12.0      ┆ 5        ┆ 19.0      │
│ uk     ┆ sen2  ┆ 3        ┆ 10.0      ┆ 2        ┆ 14.0      │
└────────┴───────┴──────────┴───────────┴──────────┴───────────┘

We use .get() to fetch the values at the .arg_min() and .arg_max() indexes, and .name.suffix() to create the new column names.

answered Sep 22 '25 03:09 by jqurious