I have found that in Apache Spark SQL (version 2.2.0), when a user-defined aggregate function (UDAF) used over a window specification is supplied with multiple rows of identical input, the UDAF's evaluate method is (seemingly) not called correctly.
I have been able to reproduce this behavior both in Java and Scala, locally and on a cluster. The Scala code below shows an example where rows are marked as false if they fall within 1 second of the previously kept row.
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ExampleUDAF(val timeLimit: Long) extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true
  def inputSchema: StructType = StructType(Array(StructField("unix_time", LongType)))
  def dataType: DataType = BooleanType
  def bufferSchema: StructType = StructType(Array(
    StructField("previousKeepTime", LongType),
    StructField("keepRow", BooleanType)
  ))

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = false
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (buffer(0) == 0L) {
      // First row seen in this buffer: keep it and remember its timestamp.
      buffer(0) = input.getLong(0)
      buffer(1) = true
    } else {
      val timeDiff = input.getLong(0) - buffer.getLong(0)
      if (timeDiff < timeLimit) {
        // Within the time limit of the last kept row: drop it.
        buffer(1) = false
      } else {
        // Far enough from the last kept row: keep it and update the timestamp.
        buffer(0) = input.getLong(0)
        buffer(1) = true
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {} // Not implemented
  def evaluate(buffer: Row): Boolean = buffer.getBoolean(1)
}
import java.util.Arrays

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.column

val timeLimit = 1000 // 1 second
val udaf = new ExampleUDAF(timeLimit)

val window = Window
  .orderBy(column("unix_time"))
  .partitionBy(column("category"))

val df = spark.createDataFrame(Arrays.asList(
  Row(1510000001000L, "a", true),
  Row(1510000001000L, "a", false),
  Row(1510000001000L, "a", false),
  Row(1510000001000L, "a", false),
  Row(1510000700000L, "a", true),
  Row(1510000700000L, "a", false)
), new StructType().add("unix_time", LongType).add("category", StringType).add("expected_result", BooleanType))

df.withColumn("actual_result", udaf(column("unix_time")).over(window)).show
Below is the output of running the code above. The first row is expected to have an actual_result value of true, as there is no prior data. When the unix_time input is modified to have 1 millisecond between each record, the UDAF works as expected. Adding print statements in the UDAF methods revealed that evaluate is only called once, at the end, and that the buffer was correctly updated to true in the update method, but this is not what is returned after the UDAF completes.
+-------------+--------+---------------+-------------+
| unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000| a| true| false| // Should be true as it is the first element
|1510000001000| a| false| false|
|1510000001000| a| false| false|
|1510000001000| a| false| false|
|1510000700000| a| true| false| // Should be true as it is more than 1000 milliseconds after the previous row
|1510000700000| a| false| false|
+-------------+--------+---------------+-------------+
Am I understanding Spark's UDAF behavior correctly when used over window specifications? If not, could anyone offer any insight into this area? If my understanding of UDAF behavior over windows is correct, could this be a bug in Spark? Thank you.
One problem with your code is that the window specification does not say which rows the frame should cover via rowsBetween(). If there is no rowsBetween() specification, for each row the window function takes all (see update below) rows before and after the current one, including the current one (within the given category). So the actual_result for every row effectively takes into account only the last two rows of your example DataFrame, those with unix_time=1510000700000, which yields false for all rows.
With a window declaration like this:
Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)
you are always looking at only the previous row and the current row, with the previous row processed first. This produces the correct output. But since the ordering of rows with the same unix_time is not unique, it is not possible to predict which row will get the value true among rows with identical unix_time.
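For completeness, here is a minimal sketch of applying the UDAF over this window, reusing the df and udaf values defined in the question (fixedWindow is just a name chosen here for illustration):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

// Frame covering only the previous row and the current row.
val fixedWindow = Window
  .partitionBy(col("category"))
  .orderBy(col("unix_time"))
  .rowsBetween(-1L, 0L)

df.withColumn("actual_result", udaf(col("unix_time")).over(fixedWindow)).show()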
The result could look like this:
+-------------+--------+---------------+-------------+
| unix_time|category|expected_result|actual_result|
+-------------+--------+---------------+-------------+
|1510000001000| a| false| true|
|1510000001000| a| false| false|
|1510000001000| a| false| false|
|1510000001000| a| true| false|
|1510000700000| a| true| true|
|1510000700000| a| false| false|
+-------------+--------+---------------+-------------+
Update
After investigating further, it seems that when an orderBy column is provided, the frame takes all elements before the current row plus the current row, not all elements of the partition as I said before. In addition, if the orderBy column contains duplicate values, the window for each duplicated row will contain all of its duplicates. You can see this clearly by doing:
import org.apache.spark.sql.functions.{col, count}

val wA = Window.partitionBy(col("category")).orderBy(col("unix_time"))
val wB = Window.partitionBy(col("category"))
val wC = Window.partitionBy(col("category")).orderBy(col("unix_time")).rowsBetween(-1L, 0L)

df.withColumn("countRows", count(col("unix_time")).over(wA)).show()
df.withColumn("countRows", count(col("unix_time")).over(wB)).show()
df.withColumn("countRows", count(col("unix_time")).over(wC)).show()
which will count the number of elements in each window:
- wA has 4 elements in the window of every 1510000001000 row and 6 elements in the window of every 1510000700000 row.
- wB has no orderBy, so all rows of the partition are included in every window; every window therefore has 6 elements.
- wC specifies the selection of rows explicitly, so there is no ambiguity about which rows belong to which window. There is only 1 element in the window of the first row and 2 elements in the windows of all subsequent rows, which produces the correct result.

I learned something new today too :)
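For reference, my understanding is that the implicit frame used when orderBy is present is equivalent to spelling it out with rangeBetween. A minimal sketch, assuming Spark 2.1+ where the Window.unboundedPreceding and Window.currentRow constants exist (wAExplicit is just a name chosen here):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count}

// Default frame when an ordering is defined:
// RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
// All rows sharing the current row's unix_time are peers,
// so they all fall inside the same frame.
val wAExplicit = Window
  .partitionBy(col("category"))
  .orderBy(col("unix_time"))
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)

// Should report the same counts as wA above.
df.withColumn("countRows", count(col("unix_time")).over(wAExplicit)).show()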