
Spark window function with condition on current row


I am trying to count, for a given order_id, how many orders there were in the past 365 days that had a payment. That part is not the problem: I use a window function.

Where it gets tricky for me: I don't want to count orders in this time window whose payment_date is after the order_date of the current order_id.

Currently, I have something like this:

import org.apache.spark.sql.expressions.{Window, WindowSpec}

val window: WindowSpec = Window
  .partitionBy("customer_id")
  .orderBy("order_date")
  .rangeBetween(-365 * days, -1) // days: a per-day unit constant, not defined in this snippet

and

df.withColumn("paid_order_count", count("*") over window)

which would count all orders for the customer within the 365 days before the current order.

How can I now incorporate a condition for the counting that takes the order_date of the current order into account?
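
A seemingly natural attempt is a conditional count over the same window, but that does not do what I want, because inside the aggregate the columns refer to the rows of the frame, not to the current row (rough sketch):

// naive sketch that does NOT work: inside the windowed aggregate, $"order_date"
// refers to each frame row's order_date, not to the current row's order_date
import org.apache.spark.sql.functions.{count, when}

df.withColumn("paid_order_count",
  count(when($"payment_date" <= $"order_date", 1)) over window)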

Example:

+---------+-----------+-------------+------------+
|order_id |order_date |payment_date |customer_id |
+---------+-----------+-------------+------------+
|1        |2017-01-01 |2017-01-10   |A           |
|2        |2017-02-01 |2017-02-10   |A           |
|3        |2017-02-02 |2017-02-20   |A           |
+---------+-----------+-------------+------------+

The resulting table should look like this:

+---------+-----------+-------------+------------+-----------------+
|order_id |order_date |payment_date |customer_id |paid_order_count |
+---------+-----------+-------------+------------+-----------------+
|1        |2017-01-01 |2017-01-10   |A           |0                |
|2        |2017-02-01 |2017-02-10   |A           |1                |
|3        |2017-02-02 |2017-02-20   |A           |1                |
+---------+-----------+-------------+------------+-----------------+

For order_id = 3, the paid_order_count should not be 2 but 1, because order_id = 2 is paid after order_id = 3 is placed.

I hope that I explained my problem well and look forward to your ideas!

asked Oct 19 '18 by Siruphuhn

1 Answer

Very good question! A couple of remarks first: a frame defined by a fixed number of rows (rowsBetween) is based on the row count and not on the values of the ordering column, so it would be problematic in these cases:

  1. the customer does not have orders every single day, so a 365-row window might contain rows with an order_date well before one year ago
  2. the customer has more than one order per day, which would mess up the one-year coverage
  3. a combination of 1 and 2

Also, rangeBetween (which is value-based) does not work with Date and Timestamp datatypes directly, which is why the dates are cast to Unix-timestamp longs below.
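
As a side note, once the order dates are cast to long seconds (as done in the code below), a value-based one-year frame does become possible; a minimal sketch, assuming a numeric order_date_ts column in seconds. It still cannot express the condition on the current row's order_date, which is why the list-plus-UDF approach below is used instead:

import org.apache.spark.sql.expressions.Window

// sketch only: a value-based one-year frame over seconds (assumes order_date_ts: long)
val secondsPerDay = 60L * 60 * 24
val oneYearWindow = Window
  .partitionBy("customer_id")
  .orderBy("order_date_ts")                 // numeric column, so rangeBetween works
  .rangeBetween(-365L * secondsPerDay, -1)
// counting over this frame alone cannot reference the current row's order_date,
// so it cannot exclude payments made after the current order was placed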

To solve this, one option is to use a window function that collects lists, combined with a UDF:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._ // for .toDF and the $"..." column syntax (implicit in spark-shell)

val df = spark.sparkContext.parallelize(Seq(
    (1, "2017-01-01", "2017-01-10", "A"),
    (2, "2017-02-01", "2017-02-10", "A"),
    (3, "2017-02-02", "2017-02-20", "A")
  )).toDF("order_id", "order_date", "payment_date", "customer_id")
  // cast both dates to Unix timestamps (long seconds) so they can be compared numerically
  .withColumn("order_date_ts", to_timestamp($"order_date", "yyyy-MM-dd").cast("long"))
  .withColumn("payment_date_ts", to_timestamp($"payment_date", "yyyy-MM-dd").cast("long"))

// df.printSchema()
// df.show(false)

// all earlier orders of the same customer, strictly before the current row
val window = Window
  .partitionBy("customer_id")
  .orderBy("order_date_ts")
  .rangeBetween(Window.unboundedPreceding, -1)

// counts how many payment timestamps fall within the `days` days before `top`
val count_filtered_dates = udf((days: Int, top: Long, array: Seq[Long]) => {
  val bottom = top - (days * 60 * 60 * 24).toLong // Spark timestamps are in seconds
  array.count(v => v >= bottom && v < top)
})

val res = df
  .withColumn("paid_orders", collect_list("payment_date_ts") over window)
  .withColumn("paid_order_count", count_filtered_dates(lit(365), $"order_date_ts", $"paid_orders"))

res.show(false)

Output:

+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+
|order_id|order_date|payment_date|customer_id|order_date_ts|payment_date_ts|paid_orders             |paid_order_count|
+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+
|1       |2017-01-01|2017-01-10  |A          |1483228800   |1484006400     |[]                      |0               |
|2       |2017-02-01|2017-02-10  |A          |1485907200   |1486684800     |[1484006400]            |1               |
|3       |2017-02-02|2017-02-20  |A          |1485993600   |1487548800     |[1484006400, 1486684800]|1               |
+--------+----------+------------+-----------+-------------+---------------+------------------------+----------------+

Converting dates to Spark timestamps in seconds makes the lists more memory efficient.

This is the easiest code to implement, but it is not the most optimal, as the collected lists take up some memory; a custom UDAF would be better but requires more coding (I might add it later). It will still work even if you have thousands of orders per customer.
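
If the memory of the collected lists ever becomes a concern, a plain self-join is another option; a rough sketch, using the df, order_date_ts and payment_date_ts from above:

// alternative sketch: self-join instead of collect_list + UDF (assumes df from above)
import org.apache.spark.sql.functions._

val res2 = df.as("cur")
  .join(df.as("prev"),
    $"prev.customer_id" === $"cur.customer_id" &&
    $"prev.order_date_ts" >= $"cur.order_date_ts" - 365L * 24 * 60 * 60 && // within the last 365 days
    $"prev.order_date_ts" <  $"cur.order_date_ts" &&                       // placed before the current order
    $"prev.payment_date_ts" < $"cur.order_date_ts",                        // and already paid by then
    "left_outer")
  .groupBy($"cur.order_id", $"cur.order_date", $"cur.payment_date", $"cur.customer_id")
  .agg(count($"prev.order_id").as("paid_order_count"))

res2.show(false)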

answered Nov 15 '22 by alexeipab