
How to define a custom aggregation function to sum a column of Vectors?


I have a DataFrame of two columns, ID of type Int and Vec of type Vector (org.apache.spark.mllib.linalg.Vector).

The DataFrame looks as follows:

ID,Vec
1,[0,0,5]
1,[4,0,1]
1,[1,2,1]
2,[7,5,0]
2,[3,3,4]
3,[0,8,1]
3,[0,0,1]
3,[7,7,7]
....

I would like to do a groupBy($"ID") then apply an aggregation on the rows inside each group by summing the vectors.

The desired output of the above example would be:

ID,SumOfVectors
1,[5,2,7]
2,[10,8,4]
3,[7,15,9]
...

The built-in aggregation functions will not work here; for example, df.groupBy($"ID").agg(sum($"Vec")) throws a ClassCastException.

How can I implement a custom aggregation function that lets me sum vectors or arrays, or perform any other custom operation?

asked Nov 24 '15 by Rami

1 Answer

Spark >= 3.0

You can use Summarizer with sum

import org.apache.spark.ml.stat.Summarizer

df
  .groupBy($"id")
  .agg(Summarizer.sum($"vec").alias("vec"))
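
If you need more than one statistic per group, Summarizer.metrics builds a single summary struct you can select fields from. A minimal sketch assuming the same df as in the question (the aliases vec_sum and vec_mean are my own):

val stats = df
  .groupBy($"id")
  // compute all requested metrics in one pass over each group
  .agg(Summarizer.metrics("sum", "mean").summary($"vec").alias("stats"))
  // pull individual metrics out of the resulting struct column
  .select($"id", $"stats.sum".alias("vec_sum"), $"stats.mean".alias("vec_mean"))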

Spark < 3.0

Personally, I wouldn't bother with UDAFs. They are more than verbose and not exactly fast (see Spark UDAF with ArrayType as bufferSchema performance issues). Instead I would simply use reduceByKey / foldByKey:

import org.apache.spark.sql.Row
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.{Vector, Vectors}

import spark.implicits._  // needed for .toDF on an RDD

def dv(values: Double*): Vector = Vectors.dense(values.toArray)

val df = spark.createDataFrame(Seq(
    (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
    (2, dv(7,5,0)), (2, dv(3,3,4)),
    (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
  ).toDF("id", "vec")

val aggregated = df
  .rdd
  .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
  .foldByKey(BDV.zeros[Double](3))(_ += _)  // sum in place into the zero buffer
  .mapValues(v => Vectors.dense(v.toArray))
  .toDF("id", "vec")

aggregated.show

// +---+--------------+
// | id|           vec|
// +---+--------------+
// |  1| [5.0,2.0,7.0]|
// |  2|[10.0,8.0,4.0]|
// |  3|[7.0,15.0,9.0]|
// +---+--------------+
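
The same aggregation can be expressed with reduceByKey. A minimal sketch reusing df and the imports from the block above (the val name aggregated2 is arbitrary):

val aggregated2 = df
  .rdd
  .map { case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
  .reduceByKey(_ + _)  // Breeze's + allocates a fresh vector on every merge
  .mapValues(v => Vectors.dense(v.toArray))
  .toDF("id", "vec")

The foldByKey version with += avoids that per-merge allocation; mutating the zero value is safe because Spark deserializes a fresh copy of it for every key.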

And, just for comparison, a "simple" UDAF. Required imports:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
  UserDefinedAggregateFunction}
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
import org.apache.spark.sql.Row
import scala.collection.mutable.WrappedArray

Class definition:

class VectorSum(n: Int) extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
  def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
  def dataType = SQLDataTypes.VectorType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    // start from an all-zeros buffer of the expected length
    buffer.update(0, Array.fill(n)(0.0))
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0)) {
      val buff = buffer.getAs[WrappedArray[Double]](0)
      // only the active (non-zero) entries need to be added
      val v = input.getAs[Vector](0).toSparse
      for (i <- v.indices) {
        buff(i) += v(i)
      }
      buffer.update(0, buff)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val buff1 = buffer1.getAs[WrappedArray[Double]](0)
    val buff2 = buffer2.getAs[WrappedArray[Double]](0)
    for ((x, i) <- buff2.zipWithIndex) {
      buff1(i) += x
    }
    buffer1.update(0, buff1)
  }

  def evaluate(buffer: Row) = Vectors.dense(
    buffer.getAs[Seq[Double]](0).toArray)
}

And an example usage:

df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show  // +---+--------------+ // | id|           vec| // +---+--------------+ // |  1| [5.0,2.0,7.0]| // |  2|[10.0,8.0,4.0]| // |  3|[7.0,15.0,9.0]| // +---+--------------+ 

See also: How to find mean of grouped Vector columns in Spark SQL?

answered Sep 20 '22 by zero323