 

Find min value for every 5-hour interval

My df

val df = Seq(
  ("1", 1),
  ("1", 1),
  ("1", 2),
  ("1", 4),
  ("1", 5),
  ("1", 6),
  ("1", 8),
  ("1", 12),
  ("1", 12),
  ("1", 13),
  ("1", 14),
  ("1", 15),
  ("1", 16)
).toDF("id", "time")

In this case, the first interval starts at hour 1, so every row up to 6 (1 + 5) is part of this interval.

But 8 - 1 > 5, so the second interval starts at 8 and goes up to 13.

Then 14 - 8 > 5, so the third interval starts, and so on.

The desired result:

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |8       |
|1  |12  |8       |
|1  |13  |8       |
|1  |14  |14      |
|1  |15  |14      |
|1  |16  |14      |
+---+----+--------+
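
To make the rule concrete, the same grouping can be written over just the time values in plain Scala (no Spark); this sketch reproduces the min_time column above:

val times = Seq(1, 1, 2, 4, 5, 6, 8, 12, 12, 13, 14, 15, 16)

// Carry the start of the current interval; open a new one whenever the current
// time is more than 5 away from that start.
val minTimes = times.scanLeft(Option.empty[Int]) {
  case (None, t)                         => Some(t) // first row opens the first interval
  case (Some(start), t) if t - start > 5 => Some(t) // new interval starts at t
  case (start, _)                        => start   // still inside the current interval
}.flatten

// minTimes: 1, 1, 1, 1, 1, 1, 8, 8, 8, 8, 14, 14, 14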

I'm trying to do it using the min function, but I don't know how to account for this condition.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{min, when}
import spark.implicits._

val window = Window.partitionBy($"id").orderBy($"time")

df
  .select($"id", $"time")
  .withColumn("min_time", when(($"time" - min($"time").over(window)) <= 5, min($"time").over(window)).otherwise($"time"))
  .show(false)

What I get:

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |12      |
|1  |12  |12      |
|1  |13  |13      |
|1  |14  |14      |
|1  |15  |15      |
|1  |16  |16      |
+---+----+--------+

1 Answer

You can go with your first idea of using an aggregation function over a window. But instead of combining Spark's built-in functions, you can define your own user-defined aggregate function (UDAF).

Analysis

As you correctly supposed, we should use a kind of min function over a window. On the rows of this window, we want to implement the following rule:

Given rows sorted by time, if the difference between the current row's time and the previous row's min_time is greater than 5, then the current row's min_time should be the current row's time; otherwise, the current row's min_time should be the previous row's min_time.

However, with the aggregate functions provided by Spark, we can't access the previous row's min_time. There is a lag function, but it only gives access to values that are already present in previous rows. As the previous row's min_time is being computed rather than already present, we can't access it.
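
For instance, lag can give us the previous row's time, but not the value we are in the middle of computing (a minimal sketch reusing the df and window from the question):

import org.apache.spark.sql.functions.{col, lag}

// We can look at values that already exist in the input rows...
df.withColumn("prev_time", lag(col("time"), 1).over(window))
// ...but there is no lag(col("min_time"), 1) to use here, because min_time is
// precisely the column we are trying to compute.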

Thus, we have to define our own aggregate function.

Solution

Defining a tailor-made aggregate function

To define our aggregate function, we need to create an object that extends the Aggregator abstract class. Below is the complete implementation:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

object MinByInterval extends Aggregator[Integer, Integer, Integer] {

  // No interval has been opened yet.
  def zero: Integer = null

  // buffer holds the start (min_time) of the current interval; open a new
  // interval when the incoming time is more than 5 away from it.
  def reduce(buffer: Integer, time: Integer): Integer = {
    if (buffer == null || time - buffer > 5) time else buffer
  }

  // Never called when the aggregator is evaluated over a window.
  def merge(b1: Integer, b2: Integer): Integer = {
    throw new NotImplementedError("should not use as general aggregation")
  }

  def finish(reduction: Integer): Integer = reduction

  def bufferEncoder: Encoder[Integer] = Encoders.INT

  def outputEncoder: Encoder[Integer] = Encoders.INT

}

We use Integer for the input, buffer and output types. We chose Integer because it is a nullable Int. We could have used Option[Int]; however, the Spark documentation advises against recreating objects in aggregator methods for performance reasons, which is what would happen if we used a complex type like Option.

We implement the rule defined in the Analysis section in the reduce method:

def reduce(buffer: Integer, time: Integer): Integer = {
  if (buffer == null || time - buffer > 5) time else buffer
}

Here, time is the value of the time column for the current row, and buffer is the previously computed value, i.e. the min_time of the previous row. Since the rows in our window are sorted by time, time is always greater than or equal to buffer. The null buffer case only happens when processing the first row.
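
For example, applying reduce by hand to a few consecutive values from the question's data shows where new intervals are opened:

MinByInterval.reduce(null, 1)  // -> 1   first row opens the first interval
MinByInterval.reduce(1, 6)     // -> 1   6 - 1 = 5, not > 5, same interval
MinByInterval.reduce(1, 8)     // -> 8   8 - 1 = 7 > 5, new interval
MinByInterval.reduce(8, 13)    // -> 8   13 - 8 = 5, same interval
MinByInterval.reduce(8, 14)    // -> 14  14 - 8 = 6 > 5, new interval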

The merge method is not used when the aggregate function is applied over a window, so we don't implement it.

The finish method is the identity, as we don't need to perform any final calculation on the aggregated value, and the buffer and output encoders are both Encoders.INT.

Calling the user-defined aggregate function

Now we can call our user-defined aggregate function with the following code:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val minTime = udaf(MinByInterval)
val window = Window.partitionBy("id").orderBy("time")
df.withColumn("min_time", minTime(col("time")).over(window))
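
If you also want to call the function from Spark SQL, the wrapped function can be registered and then used in a window expression. This is only a sketch (the view name events and the function name min_by_interval are arbitrary, and it assumes Spark 3.x, where udaf is available):

spark.udf.register("min_by_interval", minTime)
df.createOrReplaceTempView("events")

spark.sql("""
  SELECT id, time,
         min_by_interval(time) OVER (PARTITION BY id ORDER BY time) AS min_time
  FROM events
""").show(false)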

Run

Given the input dataframe in the question, we get:

+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1  |1   |1       |
|1  |1   |1       |
|1  |2   |1       |
|1  |4   |1       |
|1  |5   |1       |
|1  |6   |1       |
|1  |8   |8       |
|1  |12  |8       |
|1  |12  |8       |
|1  |13  |8       |
|1  |14  |14      |
|1  |15  |14      |
|1  |16  |14      |
+---+----+--------+