 

Apache Spark - How to define a UserDefinedAggregateFunction in Spark 3.0 and later?

I'm using Spark 3.0, and in order to use a user-defined function as a window function, I need an instance of UserDefinedAggregateFunction. Initially I thought that using the new Aggregator and udaf would solve this problem (as shown here), but udaf returns a UserDefinedFunction, not a UserDefinedAggregateFunction.

Since Spark 3.0, UserDefinedAggregateFunction is deprecated, as stated here (although it can still be found in the API and in many examples).

So the question is: is there a correct (not deprecated) way in Spark 3.0 to define a proper UserDefinedAggregateFunction and use it as a window function?

PiFace asked May 23 '26 09:05



1 Answer

In Spark 3, the new API uses Aggregator to define user-defined aggregations:

abstract class Aggregator[-IN, BUF, OUT] extends Serializable:

A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value.

Aggregator brings performance improvements compared to the deprecated UDAF; see the Spark issue "Efficient User Defined Aggregators" for details.

Here's an example of how to define a mean Aggregator and wrap it with the functions.udaf method:

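import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Type parameters: input = Long, buffer = (sum, count), output = Double
val meanAgg = new Aggregator[Long, (Long, Long), Double]() {

    // Empty buffer: sum = 0, count = 0
    def zero = (0L, 0L)

    // Fold one input value into the buffer
    def reduce(y: (Long, Long), x: Long) = (y._1 + x, y._2 + 1)

    // Combine two partial buffers
    def merge(a: (Long, Long), b: (Long, Long)) = (a._1 + b._1, a._2 + b._2)

    // Turn the final buffer into the mean
    def finish(r: (Long, Long)) = r._1.toDouble / r._2

    // Encoders from the public Encoders API for the buffer and output types
    def bufferEncoder: Encoder[(Long, Long)] = Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)

    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Wrap the Aggregator as a UserDefinedFunction usable on DataFrame columns
val meanUdaf = udaf(meanAgg)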
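
To make the roles of zero, reduce, merge and finish more concrete, here is a minimal plain-Scala sketch (no Spark required) of how they compose: each partition is folded from zero with reduce, the partial buffers are combined with merge, and finish produces the result. The object name MeanAggregatorSketch, the aggregate helper and the way the values are split into partitions are hypothetical and only for illustration; Spark decides the actual partitioning and execution.

```scala
// Plain-Scala sketch of the Aggregator contract; not Spark's implementation.
object MeanAggregatorSketch {
  type Buf = (Long, Long) // (running sum, running count)

  val zero: Buf = (0L, 0L)
  def reduce(b: Buf, x: Long): Buf = (b._1 + x, b._2 + 1)
  def merge(a: Buf, b: Buf): Buf = (a._1 + b._1, a._2 + b._2)
  def finish(r: Buf): Double = r._1.toDouble / r._2

  // Hypothetical helper: fold every partition from zero, merge the
  // partial buffers, then finish.
  def aggregate(partitions: Seq[Seq[Long]]): Double = {
    val partials = partitions.map(_.foldLeft(zero)(reduce))
    finish(partials.foldLeft(zero)(merge))
  }

  def main(args: Array[String]): Unit = {
    // Two values split across two pretend partitions.
    println(aggregate(Seq(Seq(2L), Seq(5L)))) // 3.5
    // Two values in a single partition.
    println(aggregate(Seq(Seq(3L, 1L))))      // 2.0
  }
}
```

Because merge only adds sums and counts, the result does not depend on how the values are split, which is what allows the partial buffers to be computed independently.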

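Using it with Window: applying the UserDefinedFunction returned by udaf to a column yields a Column, so over can be called on it just like on a built-in aggregate function: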

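import org.apache.spark.sql.expressions.Window
import spark.implicits._ // assumes a SparkSession named spark, as in spark-shell

val df = Seq(
  (1, 2), (1, 5),
  (2, 3), (2, 1),
).toDF("id", "value")

df.withColumn("mean", meanUdaf($"value").over(Window.partitionBy($"id"))).show()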
//+---+-----+----+
//| id|value|mean|
//+---+-----+----+
//|  1|    2| 3.5|
//|  1|    5| 3.5|
//|  2|    3| 2.0|
//|  2|    1| 2.0|
//+---+-----+----+
blackbishop answered May 25 '26 10:05



