
Apache Spark SQL UDAF over window showing odd behaviour with duplicate input

I have found that in Apache Spark SQL (version 2.2.0), when a user-defined aggregate function (UDAF) used over a window specification is supplied with multiple rows of identical input, Spark (seemingly) does not invoke the UDAF's evaluate method correctly.

I have been able to reproduce this behavior both in Java and Scala, locally and on a cluster. The code below shows an example where a row is marked as false if it falls within 1 second of the previously kept row.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ExampleUDAF(val timeLimit: Long) extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true
  def inputSchema: StructType = StructType(Array(StructField("unix_time", LongType)))
  def dataType: DataType = BooleanType

  def bufferSchema = StructType(Array(
    StructField("previousKeepTime", LongType),
    StructField("keepRow", BooleanType)
  ))

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = 0L
    buffer(1) = false
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (buffer.getLong(0) == 0L) {
      // First row seen: keep it and remember its timestamp.
      buffer(0) = input.getLong(0)
      buffer(1) = true
    } else {
      val timeDiff = input.getLong(0) - buffer.getLong(0)

      if (timeDiff < timeLimit) {
        // Within timeLimit of the last kept row: mark this row as dropped.
        buffer(1) = false
      } else {
        // Far enough from the last kept row: keep it and update the timestamp.
        buffer(0) = input.getLong(0)
        buffer(1) = true
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {} // Not implemented
  def evaluate(buffer: Row): Boolean = buffer.getBoolean(1)
}

import java.util.Arrays

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.column

val timeLimit = 1000L // 1 second
val udaf = new ExampleUDAF(timeLimit)

val window = Window
  .orderBy(column("unix_time"))
  .partitionBy(column("category"))

val df = spark.createDataFrame(Arrays.asList(
    Row(1510000001000L, "a", true),
    Row(1510000001000L, "a", false),
    Row(1510000001000L, "a", false),
    Row(1510000001000L, "a", false),
    Row(1510000700000L, "a", true),
    Row(1510000700000L, "a", false)
  ), new StructType().add("unix_time", LongType).add("category", StringType).add("expected_result", BooleanType))

df.withColumn("actual_result", udaf(column("unix_time")).over(window)).show

Below is the output of running the code above. The first row is expected to have an actual_result value of true, as there is no prior data. When the unix_time input is modified so that there is 1 millisecond between each record, the UDAF works as expected.

Adding print statements to the UDAF methods revealed that evaluate is only called once, at the end, and that the buffer was correctly updated to true in the update method, but that is not the value returned after the UDAF completes.

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|           true|        false|  // Should be true as first element
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000700000|       a|           true|        false|  // Should be true as more than 1000 milliseconds after the previous row
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+
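For reference, this is a minimal sketch of the modified input mentioned above; the exact millisecond offsets and the name dfSpaced are mine and purely illustrative, but with input like this the UDAF produced the expected results:

// Sketch (illustrative values): same schema, but with unix_time spaced 1 ms apart.
val dfSpaced = spark.createDataFrame(Arrays.asList(
    Row(1510000001000L, "a", true),
    Row(1510000001001L, "a", false),
    Row(1510000001002L, "a", false),
    Row(1510000001003L, "a", false),
    Row(1510000700000L, "a", true),
    Row(1510000700001L, "a", false)
  ), new StructType().add("unix_time", LongType).add("category", StringType).add("expected_result", BooleanType))

dfSpaced.withColumn("actual_result", udaf(column("unix_time")).over(window)).show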

Am I understanding Spark's UDAF behaviour correctly when it is used over window specifications? If not, could anyone offer insight into this area? If my understanding of UDAF behaviour over windows is correct, could this be a bug in Spark? Thank you.

asked Nov 29 '17 by ab853


1 Answer

One problem is that your window specification does not say which rows the window should cover via rowsBetween(). Without a rowsBetween() clause, the window for each row takes all rows before and after the current one, including the current one, within the given category (but see the update below). So the actual_result for every row effectively takes into account only the last two rows of your example DataFrame, those with unix_time=1510000700000, which is why false is returned for all rows.

With a window declaration like this:

Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)

You always look only at the previous row and the current row, with the previous row processed first. This produces the correct output. However, because the ordering of rows sharing the same unix_time is not deterministic, it is not possible to predict which of the rows with identical unix_time will get the value true.
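As a minimal sketch of applying it (reusing the udaf and df from the question; fixedWindow is just an illustrative name):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

// Sketch: same call as in the question, only the window now has a bounded frame.
val fixedWindow = Window
  .partitionBy(col("category"))
  .orderBy(col("unix_time"))
  .rowsBetween(-1L, 0L)

df.withColumn("actual_result", udaf(col("unix_time")).over(fixedWindow)).show()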

The result could look like this:

+-------------+--------+---------------+-------------+
|    unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000|       a|          false|         true|
|1510000001000|       a|          false|        false|
|1510000001000|       a|          false|        false|
|1510000001000|       a|           true|        false|
|1510000700000|       a|           true|         true|
|1510000700000|       a|          false|        false|
+-------------+--------+---------------+-------------+

Update

After investigating further, it seems that when an orderBy column is provided, the window takes all elements before the current row plus the current row itself, not all elements of the partition as I said before. In addition, if the orderBy column contains duplicate values, the window for each duplicated row will contain all of the duplicated values. You can see this clearly by running:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count}

val wA = Window.partitionBy(col("category")).orderBy(col("unix_time"))
val wB = Window.partitionBy(col("category"))
val wC = Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)

df.withColumn("countRows", count(col("unix_time")).over(wA)).show()
df.withColumn("countRows", count(col("unix_time")).over(wB)).show()
df.withColumn("countRows", count(col("unix_time")).over(wC)).show()

which will count the number of elements in each window.

  • Window wA will have 4 elements in the window of each 1510000001000 row and 6 elements in the window of each 1510000700000 row (the equivalent explicit frame is sketched after this list).
  • For wB, since there is no orderBy, all rows of the partition are included in each window, so every window will have 6 elements.
  • The last one, wC, specifies the selection of rows explicitly, so there is no ambiguity about which rows belong to which window: there is only 1 element in the window of the first row and 2 elements in the windows of all subsequent rows, which produces the correct result.
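For completeness, here is a sketch of what wA's default frame corresponds to when written out explicitly. This assumes Spark SQL's documented default of RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW whenever an orderBy is present; wAExplicit is just an illustrative name.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count}

// Sketch: wA with its (assumed) default frame spelled out. Because the frame is
// RANGE-based, every row whose unix_time equals the current row's value falls
// inside the frame, which is why duplicates are all counted together.
val wAExplicit = Window
  .partitionBy(col("category"))
  .orderBy(col("unix_time"))
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("countRows", count(col("unix_time")).over(wAExplicit)).show()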

I learned something new today too :)

answered by astro_asz