 

Apply a custom Spark Aggregator on multiple columns (Spark 2.0)

I've created a custom Aggregator[] for Strings.

I would like to apply it to all columns of a DataFrame in which every column is a string, but the number of columns is arbitrary.

I'm stuck writing the right expression. I would like to write something like this:

df.agg( df.columns.map( c => myagg(df(c)) ) : _*) 

which is obviously wrong given the various interfaces.

I had a look at the RelationalGroupedDataset.agg(expr: Column, exprs: Column*) code, but I'm not familiar with expression manipulation.

Any ideas?

mathieu asked Dec 19 '22 10:12


1 Answer

In contrast to UserDefinedAggregateFunctions, which operate on individual fields (columns), Aggregators expect a complete Row / value.
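For comparison, here is a minimal sketch of the UDAF route, assuming nullable Int columns like the ones in the example further down. `MaxIntUDAF` is a hypothetical name; `UserDefinedAggregateFunction` is the Spark 2.x API (deprecated since Spark 3.0). Because its input schema declares a single field, it is applied to one Column at a time, which is the shape the question's `df.columns.map(...)` snippet expects. Treat it as an untested sketch against the Spark 2.x/3.x Scala API.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Hypothetical UDAF: maximum of ONE nullable Int column.
// inputSchema has a single field, so update() only ever sees that field.
object MaxIntUDAF extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", IntegerType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("max", IntegerType) :: Nil)
  def dataType: DataType = IntegerType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Int.MinValue

  // `input` holds just the column the UDAF was applied to; nulls are skipped
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = math.max(buffer.getInt(0), input.getInt(0))

  // Combine partial maxima from different partitions
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = math.max(buffer1.getInt(0), buffer2.getInt(0))

  def evaluate(buffer: Row): Any = buffer.getInt(0)
}

val spark = SparkSession.builder().master("local[*]").appName("udaf-per-column").getOrCreate()
import spark.implicits._

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z")

// A UDAF is applied to a Column, so mapping over the column names works as-is
val udafExprs = df.columns.map(c => MaxIntUDAF(col(c)).alias(s"max($c)"))
val udafResult = df.agg(udafExprs.head, udafExprs.tail: _*).head()
```

Here `udafResult` is a single Row holding the maxima of `x`, `y` and `z`, in column order.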

If you want an Aggregator that can be used as in your snippet, it has to be parametrized by the column name and take Row as its input type.

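import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, Row}

// Takes whole Rows as input and is parametrized by the column to read.
// (Aggregator already extends Serializable, so no extra mixin is needed.)
case class Max(col: String) extends Aggregator[Row, Int, Int] {

  // Identity element for max
  def zero: Int = Int.MinValue

  // Fold one Row into the running maximum. Nulls are checked explicitly,
  // because x.getAs[Int](col) on a null field unboxes to 0 rather than null.
  def reduce(acc: Int, x: Row): Int = {
    val i = x.fieldIndex(col)
    if (x.isNullAt(i)) acc else Math.max(acc, x.getInt(i))
  }

  // Combine partial maxima computed on different partitions
  def merge(acc1: Int, acc2: Int): Int = Math.max(acc1, acc2)
  def finish(acc: Int): Int = acc

  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}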

Example usage:

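import spark.implicits._  // enables toDF; `spark` is the active SparkSession (predefined in spark-shell)

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z")

// @transient: REPL/notebook precaution so this field is not serialized with closures
@transient val exprs = df.columns.map(c => Max(c).toColumn.alias(s"max($c)"))

df.agg(exprs.head, exprs.tail: _*).show()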
+------+------+------+
|max(x)|max(y)|max(z)|
+------+------+------+
|     4|     5|     3|
+------+------+------+
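Since the question's columns are all strings, here is the same Row-based pattern with a String buffer. `MaxString` and the sample data are hypothetical; the sketch assumes "largest" means lexicographic order and uses the empty string as the zero, because it compares less than or equal to every other string. It is an untested sketch, not a drop-in replacement for the asker's own Aggregator.

```scala
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical String counterpart of Max: lexicographically greatest
// non-null value of column `col`. "" is the identity for this max,
// so it serves as zero (and is also the result for an all-null column).
case class MaxString(col: String) extends Aggregator[Row, String, String] {
  def zero: String = ""

  def reduce(acc: String, x: Row): String = {
    val v = x.getAs[String](col) // null when the field is SQL NULL
    if (v == null || acc.compareTo(v) >= 0) acc else v
  }

  def merge(acc1: String, acc2: String): String =
    if (acc1.compareTo(acc2) >= 0) acc1 else acc2

  def finish(acc: String): String = acc

  def bufferEncoder: Encoder[String] = Encoders.STRING
  def outputEncoder: Encoder[String] = Encoders.STRING
}

val spark = SparkSession.builder().master("local[*]").appName("max-string").getOrCreate()
import spark.implicits._

val strDf = Seq(("a", "pear", null), ("c", "apple", "kiwi")).toDF("p", "q", "r")

// Same head/tail trick as above to feed an arbitrary number of columns to agg
val strExprs = strDf.columns.map(c => MaxString(c).toColumn.alias(s"max($c)"))
val strResult = strDf.agg(strExprs.head, strExprs.tail: _*).head()
```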

Arguably, Aggregators make much more sense with statically typed Datasets than with Dataset<Row>.
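To illustrate the typed alternative, here is a sketch assuming a hypothetical `Record` case class that mirrors the `x`, `y`, `z` columns of the example above. The Aggregator consumes `Record`s directly, so there is no lookup by column name and no null unboxing to worry about; `Option[Int]` models the nullable column. `MaxY` is likewise a hypothetical name.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical record type mirroring the x/y/z example
case class Record(x: Int, y: Option[Int], z: Int)

// Typed Aggregator: the input is a Record, so fields are read with
// ordinary, compiler-checked accessors instead of Row.getAs by name.
object MaxY extends Aggregator[Record, Int, Int] {
  def zero: Int = Int.MinValue
  // Option.fold: keep acc when y is None, otherwise take the larger value
  def reduce(acc: Int, r: Record): Int = r.y.fold(acc)(v => math.max(acc, v))
  def merge(acc1: Int, acc2: Int): Int = math.max(acc1, acc2)
  def finish(acc: Int): Int = acc

  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

val spark = SparkSession.builder().master("local[*]").appName("typed-agg").getOrCreate()
import spark.implicits._

val ds = Seq(Record(1, None, 3), Record(4, Some(5), -6)).toDS()

// Typed select on a Dataset[Record] yields a Dataset[Int]
val maxY: Int = ds.select(MaxY.toColumn).head()
```

The trade-off is that the field is fixed at compile time, which fits poorly with an arbitrary number of columns but catches type errors early.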

Depending on your requirements, you could also aggregate multiple columns in a single pass, using a Seq[_] accumulator and processing a whole Row (record) in a single reduce call.
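A minimal sketch of that single-pass variant, assuming all aggregated columns are Ints. `MaxAll` is a hypothetical name, and the encoders come from `ExpressionEncoder`, which lives in Spark's internal catalyst package (it exists in Spark 2.x and 3.x but, as far as I know, the public `Encoders` object offers no `Seq[Int]` factory). Treat it as an untested sketch.

```scala
import org.apache.spark.sql.{Encoder, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical: max of every listed Int column, computed in ONE pass.
// The buffer holds one running maximum per column, in `cols` order.
case class MaxAll(cols: Seq[String]) extends Aggregator[Row, Seq[Int], Seq[Int]] {
  def zero: Seq[Int] = Seq.fill(cols.size)(Int.MinValue)

  // A single reduce call consumes a whole Row and updates every column's maximum
  def reduce(acc: Seq[Int], row: Row): Seq[Int] =
    cols.zip(acc).map { case (c, m) =>
      val i = row.fieldIndex(c)
      if (row.isNullAt(i)) m else math.max(m, row.getInt(i))
    }

  // Element-wise max of two partial buffers
  def merge(acc1: Seq[Int], acc2: Seq[Int]): Seq[Int] =
    acc1.zip(acc2).map { case (a, b) => math.max(a, b) }

  def finish(acc: Seq[Int]): Seq[Int] = acc

  // Assumption: internal catalyst encoder used for the Seq[Int] buffer/output
  def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder[Seq[Int]]()
  def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder[Seq[Int]]()
}

val spark = SparkSession.builder().master("local[*]").appName("single-pass").getOrCreate()
import spark.implicits._

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z")

// One aggregate column holding all maxima as an array
val maxes: Seq[Int] =
  df.agg(MaxAll(df.columns.toSeq).toColumn.alias("maxes")).head().getSeq[Int](0)
```

The result comes back as a single array column, so `maxes` holds the per-column maxima in `df.columns` order.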

zero323 answered Jan 13 '23 09:01