Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Why does a mutable Map automatically become immutable in a UserDefinedAggregateFunction (UDAF) in Spark?

I am trying to define a UserDefinedAggregateFunction (UDAF) in Spark that counts the number of occurrences of each unique value in a column within a group.

This is an example: Suppose I have a dataframe df like this,

+----+----+
|col1|col2|
+----+----+
|   a|  a1|
|   a|  a1|
|   a|  a2|
|   b|  b1|
|   b|  b2|
|   b|  b3|
|   b|  b1|
|   b|  b1|
+----+----+

I will have a UDAF DistinctValues

// Create an instance of the aggregate function.
val func = new DistinctValues()

Then I apply it to the dataframe df

// Group by col1 and aggregate col2 with the UDAF; the result column is named "DV".
val agg_value = df.groupBy("col1").agg(func(col("col2")).alias("DV"))

I am expecting to have something like this:

+----+--------------------------+
|col1|DV                        |
+----+--------------------------+
|   a|  Map(a1->2, a2->1)       |
|   b|  Map(b1->3, b2->1, b3->1)|
+----+--------------------------+

So I came up with a UDAF like this:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.LongType
import Array._

/**
 * UDAF intended to count, per group, the occurrences of each distinct string value.
 *
 * NOTE: this is the version from the question and it fails at runtime with
 * `ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to
 * scala.collection.mutable.Map`, raised from `update` (see the stack trace in the post).
 * It is kept as-is to illustrate the problem.
 */
class DistinctValues extends UserDefinedAggregateFunction {
  // Single input column named "value" of StringType.
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  // Aggregation buffer: one field "values" declared as MapType(StringType, LongType).
  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  // The result has the same type as the buffer field.
  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  // Stores an empty mutable map in buffer slot 0.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.mutable.Map()
  }

  // Increments the count for the incoming value and writes the map back to slot 0.
  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
    // NOTE(review): presumably this is the cast that fails -- the reported exception says
    // the value read back is an immutable Map, not the mutable one stored in `initialize`.
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp.put(str, c)
    buffer(0) = mp
  }

  // Adds every count from buffer2's map into buffer1's map and writes the result to buffer1.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    // Same mutable.Map casts as in `update`.
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0)
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1.put(k ,c)
        }
    }
    buffer1(0) = mp1
  }

  // Returns the map held in buffer slot 0 as the aggregation result.
  def evaluate(buffer: Row): Any = {
      // NOTE(review): the value type parameter here is `LongType` (the Spark DataType),
      // whereas `update`/`merge` use `Long` -- looks unintended; confirm.
      buffer.getAs[scala.collection.mutable.Map[String, LongType]](0)
  }
}

Then I have this function on my dataframe,

// Instantiate the UDAF, then aggregate col2 per col1 group into the "DV" column.
val func = new DistinctValues()
val agg_values = df.groupBy("col1").agg(func(col("col2")).alias("DV"))

It gave such error,

func: DistinctValues = $iwC$$iwC$DistinctValues@17f48a25
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
at $iwC$$iwC$DistinctValues.update(<console>:39)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)

It looks like in the update(buffer: MutableAggregationBuffer, input: Row) method, the value held in buffer is an immutable.Map, and the program tried to cast it to a mutable.Map.

But I used a mutable.Map to initialize the buffer in the initialize(buffer: MutableAggregationBuffer) method. Is it the same variable that is passed to the update method? Also, buffer is a MutableAggregationBuffer, so it should be mutable, right?

Why did my mutable.Map become immutable? Does anyone know what happened?

I really need a mutable Map in this function to complete the task. I know there is a workaround: create a mutable map from the immutable map, then update it. But I really want to know why the mutable map is automatically turned into an immutable one; it doesn't make sense to me.

like image 800
Fan L. Avatar asked Apr 14 '16 17:04

Fan L.


1 Answers

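I believe it is because of the MapType in your StructType: the buffer is stored according to its schema rather than as the object you assigned, so reading it back gives you a plain Map, which is immutable.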

You can convert it, but why don't you just leave it immutable and do this:

mp = mp + (k -> c)

to add an entry to the immutable Map?

Working example below:

/**
 * UDAF that counts, per group, how many times each distinct string value occurs.
 *
 * The result is a `Map[String, Long]` from value to occurrence count. The buffer field
 * is declared as a Spark `MapType`, and the map read back from the buffer is immutable,
 * so each step builds a new immutable map and writes it back instead of mutating in place.
 */
class DistinctValues extends UserDefinedAggregateFunction {
  /** One string input column; must agree with the `getAs[String]` read in `update`. */
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  /** Aggregation buffer: a single map field from value to running count. */
  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType)) :: Nil)

  /** The aggregate's result type is the same map type as the buffer field. */
  def dataType: DataType = MapType(StringType, LongType)

  def deterministic: Boolean = true

  /** Starts every group with an empty count map. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map.empty[String, Long]
  }

  /**
   * Adds one occurrence of the incoming value to the group's count map.
   *
   * @param buffer aggregation buffer whose slot 0 holds the count map
   * @param input  row whose column 0 is the value to count
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Spark's MapType does not allow null keys, so null inputs are not counted.
    if (!input.isNullAt(0)) {
      val value = input.getAs[String](0)
      val counts = buffer.getAs[Map[String, Long]](0)
      buffer(0) = counts + (value -> (counts.getOrElse(value, 0L) + 1L))
    }
  }

  /**
   * Merges the counts of `buffer2` into `buffer1` by summing counts key by key.
   *
   * @param buffer1 buffer that receives the merged map
   * @param buffer2 partial aggregation result to fold in
   */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getAs[Map[String, Long]](0)
    val right = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = right.foldLeft(left) { case (acc, (key, count)) =>
      acc + (key -> (acc.getOrElse(key, 0L) + count))
    }
  }

  /** Returns the final value-to-count map for the group. */
  def evaluate(buffer: Row): Any = buffer.getAs[Map[String, Long]](0)
}
like image 101
David Griffin Avatar answered Oct 08 '22 16:10

David Griffin