This is probably easiest to explain through example. Suppose I have a DataFrame of user logins to a website, for instance:
scala> df.show(5)
+----------------+----------+
| user_name|login_date|
+----------------+----------+
|SirChillingtonIV|2012-01-04|
|Booooooo99900098|2012-01-04|
|Booooooo99900098|2012-01-06|
| OprahWinfreyJr|2012-01-10|
|SirChillingtonIV|2012-01-11|
+----------------+----------+
only showing top 5 rows
I would like to add to this a column indicating when they became an active user on the site. But there is one caveat: there is a time period during which a user is considered active, and after this period, if they log in again, their became_active date resets. Suppose this period is 5 days. Then the desired table derived from the above table would be something like this:
+----------------+----------+-------------+
| user_name|login_date|became_active|
+----------------+----------+-------------+
|SirChillingtonIV|2012-01-04| 2012-01-04|
|Booooooo99900098|2012-01-04| 2012-01-04|
|Booooooo99900098|2012-01-06| 2012-01-04|
| OprahWinfreyJr|2012-01-10| 2012-01-10|
|SirChillingtonIV|2012-01-11| 2012-01-11|
+----------------+----------+-------------+
So, in particular, SirChillingtonIV's became_active date was reset because their second login came after the active period expired, but Booooooo99900098's became_active date was not reset the second time they logged in, because it fell within the active period.
My initial thought was to use window functions with lag, and then use the lagged values to fill the became_active column; for instance, something starting roughly like:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val window = Window.partitionBy("user_name").orderBy("login_date")
val df2 = df.withColumn("tmp", lag("login_date", 1).over(window))
Then, the rule to fill in the became_active date would be: if tmp is null (i.e., if it's the first ever login) or if login_date - tmp >= 5, then became_active = login_date; otherwise, go to the next most recent value in tmp and apply the same rule. This suggests a recursive approach, which I'm having trouble imagining a way to implement.
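For concreteness, a minimal sketch of a single application of that rule (reusing df2 from above; when, col, and datediff all come from the functions._ import) might look like:

// One application of the rule: the first login (tmp is null) or a gap of
// at least 5 days sets became_active = login_date; every other row stays
// null, because it needs the previous row's answer -- the recursive step.
val oneStep = df2.withColumn("became_active",
  when(col("tmp").isNull || datediff(col("login_date"), col("tmp")) >= 5,
    col("login_date")))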
My questions: Is this a viable approach, and if so, how can I "go back" and look at earlier values of tmp until I find one where I stop? I can't, to my knowledge, iterate through values of a Spark SQL Column. Is there another way to achieve this result?
Spark >= 3.2
Recent Spark releases provide native support for session windows in both batch and structured streaming queries (see SPARK-10816 and its sub-tasks, especially SPARK-34893).
The official documentation provides a nice usage example.
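For illustration, here is a minimal sketch (my own, not taken from the documentation) using the built-in session_window function added in Spark 3.2, assuming login_date can be cast to a timestamp. Note it aggregates to one row per session, so producing a per-login became_active column would still require joining back to the original rows:

import org.apache.spark.sql.functions.{col, min, session_window}

// Group each user's logins into sessions separated by gaps of more
// than 5 days, then take the earliest login date in each session.
val sessions = df
  .withColumn("login_ts", col("login_date").cast("timestamp"))
  .groupBy(col("user_name"), session_window(col("login_ts"), "5 days"))
  .agg(min(col("login_date")).as("became_active"))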
Spark < 3.2
Here is the trick. Import a bunch of functions:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, datediff, lag, lit, min, sum}
Define windows:
val userWindow = Window.partitionBy("user_name").orderBy("login_date")
val userSessionWindow = Window.partitionBy("user_name", "session")
Find the points where new sessions start:
val newSession = (coalesce(
  datediff($"login_date", lag($"login_date", 1).over(userWindow)),
  lit(0)
) > 5).cast("bigint")

val sessionized = df.withColumn("session", sum(newSession).over(userWindow))
Find the earliest date per session:
val result = sessionized
  .withColumn("became_active", min($"login_date").over(userSessionWindow))
  .drop("session")
With the dataset defined as:
val df = Seq(
  ("SirChillingtonIV", "2012-01-04"),
  ("Booooooo99900098", "2012-01-04"),
  ("Booooooo99900098", "2012-01-06"),
  ("OprahWinfreyJr", "2012-01-10"),
  ("SirChillingtonIV", "2012-01-11"),
  ("SirChillingtonIV", "2012-01-14"),
  ("SirChillingtonIV", "2012-08-11")
).toDF("user_name", "login_date")
The result is:
+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04| <- The first session for user
|SirChillingtonIV|2012-01-11|   2012-01-11| <- The second session for user
|SirChillingtonIV|2012-01-14|   2012-01-11|
|SirChillingtonIV|2012-08-11|   2012-08-11| <- The third session for user
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+
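Note that the rows come back grouped by window partition rather than in input order; if a deterministic ordering is wanted, the standard orderBy method can be appended, e.g.:

result.orderBy($"user_name", $"login_date").show()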
Refactoring the other answer to work with PySpark
In PySpark you can do it as below.
Create the data frame:
df = sqlContext.createDataFrame(
[
("SirChillingtonIV", "2012-01-04"),
("Booooooo99900098", "2012-01-04"),
("Booooooo99900098", "2012-01-06"),
("OprahWinfreyJr", "2012-01-10"),
("SirChillingtonIV", "2012-01-11"),
("SirChillingtonIV", "2012-01-14"),
("SirChillingtonIV", "2012-08-11")
],
("user_name", "login_date"))
The above code creates a data frame like the one below:
+----------------+----------+
| user_name|login_date|
+----------------+----------+
|SirChillingtonIV|2012-01-04|
|Booooooo99900098|2012-01-04|
|Booooooo99900098|2012-01-06|
| OprahWinfreyJr|2012-01-10|
|SirChillingtonIV|2012-01-11|
|SirChillingtonIV|2012-01-14|
|SirChillingtonIV|2012-08-11|
+----------------+----------+
Now we first want to find out where the difference between consecutive login_date values is more than 5 days.
For this, do the following.
Necessary imports
from pyspark.sql import functions as f
from pyspark.sql import Window
# defining window partitions
login_window = Window.partitionBy("user_name").orderBy("login_date")
session_window = Window.partitionBy("user_name", "session")
session_df = df.withColumn(
    "session",
    f.sum(
        (f.coalesce(
            f.datediff("login_date", f.lag("login_date", 1).over(login_window)),
            f.lit(0)
        ) > 5).cast("int")
    ).over(login_window)
)
When we run the above line of code, if datediff returns NULL (i.e., on the user's first login), the coalesce function replaces the NULL with 0.
+----------------+----------+-------+
| user_name|login_date|session|
+----------------+----------+-------+
| OprahWinfreyJr|2012-01-10| 0|
|SirChillingtonIV|2012-01-04| 0|
|SirChillingtonIV|2012-01-11| 1|
|SirChillingtonIV|2012-01-14| 1|
|SirChillingtonIV|2012-08-11| 2|
|Booooooo99900098|2012-01-04| 0|
|Booooooo99900098|2012-01-06| 0|
+----------------+----------+-------+
# add a became_active column by finding the min login_date over each window
# partitioned by user_name and the session created in the step above
final_df = session_df.withColumn("became_active", f.min("login_date").over(session_window)).drop("session")
+----------------+----------+-------------+
| user_name|login_date|became_active|
+----------------+----------+-------------+
| OprahWinfreyJr|2012-01-10| 2012-01-10|
|SirChillingtonIV|2012-01-04| 2012-01-04|
|SirChillingtonIV|2012-01-11| 2012-01-11|
|SirChillingtonIV|2012-01-14| 2012-01-11|
|SirChillingtonIV|2012-08-11| 2012-08-11|
|Booooooo99900098|2012-01-04| 2012-01-04|
|Booooooo99900098|2012-01-06| 2012-01-04|
+----------------+----------+-------------+