I have a DataFrame with two columns: ID of type Int and Vec of type Vector (org.apache.spark.mllib.linalg.Vector).
The DataFrame looks as follows:
    ID,Vec
    1,[0,0,5]
    1,[4,0,1]
    1,[1,2,1]
    2,[7,5,0]
    2,[3,3,4]
    3,[0,8,1]
    3,[0,0,1]
    3,[7,7,7]
    ....
I would like to do a groupBy($"ID") and then apply an aggregation to the rows inside each group by summing the vectors.
The desired output of the above example would be:
    ID,SumOfVectors
    1,[5,2,7]
    2,[10,8,4]
    3,[7,15,9]
    ...
The available aggregation functions will not work, e.g. df.groupBy($"ID").agg(sum($"Vec")) will lead to a ClassCastException.
How to implement a custom aggregation function that allows me to do the sum of vectors or arrays or any other custom operation?
Spark >= 3.0
You can use Summarizer with sum:

    import org.apache.spark.ml.stat.Summarizer

    df
      .groupBy($"id")
      .agg(Summarizer.sum($"vec").alias("vec"))
Spark < 3.0
Personally I wouldn't bother with UDAFs. They are more than verbose and not exactly fast (see: Spark UDAF with ArrayType as bufferSchema performance issues). Instead I would simply use reduceByKey / foldByKey:
    import org.apache.spark.sql.Row
    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    // Needed for .toDF on an RDD; already in scope in spark-shell
    import spark.implicits._

    def dv(values: Double*): Vector = Vectors.dense(values.toArray)

    val df = spark.createDataFrame(Seq(
        (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
        (2, dv(7,5,0)), (2, dv(3,3,4)),
        (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
      ).toDF("id", "vec")

    val aggregated = df
      .rdd
      // Convert each vector to a Breeze dense vector keyed by id
      .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
      // Sum the vectors per key, starting from a zero vector of length 3
      .foldByKey(BDV.zeros[Double](3))(_ += _)
      .mapValues(v => Vectors.dense(v.toArray))
      .toDF("id", "vec")

    aggregated.show

    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
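To make the per-key logic of foldByKey easier to see, here is a minimal sketch that does the same grouping and element-wise summing on plain Scala collections. It is not Spark code: a Seq stands in for the RDD, and the names FoldByKeySketch, add and sumByKey are made up for this illustration.

```scala
// Plain Scala collections stand in for the RDD; no Spark is required.
object FoldByKeySketch {
  // Element-wise sum of two equal-length vectors.
  def add(a: Seq[Double], b: Seq[Double]): Seq[Double] = {
    require(a.length == b.length, "vectors must have the same length")
    a.zip(b).map { case (x, y) => x + y }
  }

  // Group rows by key, then fold each group starting from a zero vector
  // of length n, which mirrors foldByKey(zero)(op).
  def sumByKey(rows: Seq[(Int, Seq[Double])], n: Int): Map[Int, Seq[Double]] = {
    val zero: Seq[Double] = Seq.fill(n)(0.0)
    rows.groupBy(_._1).map { case (k, group) =>
      k -> group.map(_._2).foldLeft(zero)(add)
    }
  }

  // Same sample data as in the question.
  val rows: Seq[(Int, Seq[Double])] = Seq(
    (1, Seq(0.0, 0.0, 5.0)), (1, Seq(4.0, 0.0, 1.0)), (1, Seq(1.0, 2.0, 1.0)),
    (2, Seq(7.0, 5.0, 0.0)), (2, Seq(3.0, 3.0, 4.0)),
    (3, Seq(0.0, 8.0, 1.0)), (3, Seq(0.0, 0.0, 1.0)), (3, Seq(7.0, 7.0, 7.0))
  )
}
```

Calling FoldByKeySketch.sumByKey(FoldByKeySketch.rows, 3) yields the same three sums as the Spark pipeline. The Spark version differs mainly in that the fold runs per partition first and the partial results are then combined, which is why the zero value has to be a neutral element for the operation.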
And just for comparison a "simple" UDAF. Required imports:
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
      UserDefinedAggregateFunction}
    import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
    import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
    import org.apache.spark.sql.Row
    import scala.collection.mutable.WrappedArray
Class definition:
    class VectorSum (n: Int) extends UserDefinedAggregateFunction {
      // Input is a single vector column; the buffer holds the running sums
      def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
      def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
      def dataType = SQLDataTypes.VectorType
      def deterministic = true

      // Start each buffer as a zero array of length n
      def initialize(buffer: MutableAggregationBuffer) = {
        buffer.update(0, Array.fill(n)(0.0))
      }

      // Add one input vector to the buffer, skipping nulls
      def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0)) {
          val buff = buffer.getAs[WrappedArray[Double]](0)
          val v = input.getAs[Vector](0).toSparse
          for (i <- v.indices) {
            buff(i) += v(i)
          }
          buffer.update(0, buff)
        }
      }

      // Combine two partial buffers element-wise
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        val buff1 = buffer1.getAs[WrappedArray[Double]](0)
        val buff2 = buffer2.getAs[WrappedArray[Double]](0)
        for ((x, i) <- buff2.zipWithIndex) {
          buff1(i) += x
        }
        buffer1.update(0, buff1)
      }

      // Convert the final buffer back to a dense vector
      def evaluate(buffer: Row) = Vectors.dense(
        buffer.getAs[Seq[Double]](0).toArray)
    }
And an example usage:
    df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show

    // +---+--------------+
    // | id|           vec|
    // +---+--------------+
    // |  1| [5.0,2.0,7.0]|
    // |  2|[10.0,8.0,4.0]|
    // |  3|[7.0,15.0,9.0]|
    // +---+--------------+
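If the roles of initialize, update, merge and evaluate in the UDAF above are unclear, the following sketch models the four phases with plain arrays. It uses no Spark classes, and the object name UdafLifecycleSketch and the split into two partitions are assumptions made purely for illustration. It shows why merge is needed: each partition builds its own partial buffer, and the partial buffers are combined afterwards.

```scala
// Plain-Scala model of the four UDAF phases; no Spark classes are used.
object UdafLifecycleSketch {
  // initialize: every partition starts from a zeroed buffer of length n.
  def initialize(n: Int): Array[Double] = Array.fill(n)(0.0)

  // update: add one input vector into a partition-local buffer (in place).
  def update(buffer: Array[Double], input: Array[Double]): Array[Double] = {
    for (i <- input.indices) buffer(i) += input(i)
    buffer
  }

  // merge: combine two partial buffers built on different partitions.
  def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
    for (i <- b2.indices) b1(i) += b2(i)
    b1
  }

  // evaluate: turn the final buffer into the result value.
  def evaluate(buffer: Array[Double]): Seq[Double] = buffer.toSeq

  // The rows for id = 3, split across two hypothetical partitions.
  val partitionA: Seq[Array[Double]] = Seq(Array(0.0, 8.0, 1.0), Array(0.0, 0.0, 1.0))
  val partitionB: Seq[Array[Double]] = Seq(Array(7.0, 7.0, 7.0))

  def run(): Seq[Double] = {
    val bufA = partitionA.foldLeft(initialize(3))(update)
    val bufB = partitionB.foldLeft(initialize(3))(update)
    evaluate(merge(bufA, bufB))
  }
}
```

UdafLifecycleSketch.run() returns the sum for id = 3 regardless of how the rows are split between the two partitions, because element-wise addition is associative and commutative.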
See also: How to find mean of grouped Vector columns in Spark SQL?.