
Pyspark window function with condition

Suppose I have a DataFrame of events with the time difference (in seconds) between consecutive rows. The main rule is that an event counts toward a visit only if it is within 5 minutes of the previous or next event:

+--------+-------------------+--------+
|userid  |eventtime          |timeDiff|
+--------+-------------------+--------+
|37397e29|2017-06-04 03:00:00|60      |
|37397e29|2017-06-04 03:01:00|60      |
|37397e29|2017-06-04 03:02:00|60      |
|37397e29|2017-06-04 03:03:00|180     |
|37397e29|2017-06-04 03:06:00|60      |
|37397e29|2017-06-04 03:07:00|420     |
|37397e29|2017-06-04 03:14:00|60      |
|37397e29|2017-06-04 03:15:00|1140    |
|37397e29|2017-06-04 03:34:00|540     |
|37397e29|2017-06-04 03:53:00|540     |
+--------+-------------------+--------+

The challenge is to group the events into visits, where each visit runs from a start_time to the end_time of the latest eventtime that still satisfies the 5-minute condition. The output should look like this table:

+--------+-------------------+--------------------+-----------+
|userid  |start_time         |end_time            |events     |
+--------+-------------------+--------------------+-----------+
|37397e29|2017-06-04 03:00:00|2017-06-04 03:07:00 |6          |
|37397e29|2017-06-04 03:14:00|2017-06-04 03:15:00 |2          |
+--------+-------------------+--------------------+-----------+

So far I have used window lag functions and some conditions, but I do not know where to go from here:

%spark.pyspark
from pyspark.sql import functions as F
from pyspark.sql import Window as W
from pyspark.sql.functions import col

windowSpec = W.partitionBy(result_poi["userid"], result_poi["unique_reference_number"]).orderBy(result_poi["eventtime"])
windowSpecDesc = W.partitionBy(result_poi["userid"], result_poi["unique_reference_number"]).orderBy(result_poi["eventtime"].desc())

# Event time of the following row within the window, e.g. 3:00pm and 3:03pm
nextEventTime = F.lag(col("eventtime"), -1).over(windowSpec)

# Event time of the preceding row within the window, e.g. 3:03pm and 3:00pm
previousEventTime = F.lag(col("eventtime"), 1).over(windowSpec)
diffEventTime = nextEventTime - col("eventtime")

nextTimeDiff = F.coalesce((F.unix_timestamp(nextEventTime)
            - F.unix_timestamp('eventtime')), F.lit(0))
previousTimeDiff = F.coalesce((F.unix_timestamp('eventtime') -F.unix_timestamp(previousEventTime)), F.lit(0))


# Check if the next POI is equal to the current POI and has a time difference of less than 5 minutes.
validation = F.coalesce(( (nextTimeDiff < 300) | (previousTimeDiff < 300) ), F.lit(False))

# Change True to 1
visitCheck = F.coalesce((validation == True).cast("int"), F.lit(1))


result_poi.withColumn("visit_check", visitCheck).withColumn("nextTimeDiff", nextTimeDiff).select("userid", "eventtime", "nextTimeDiff", "visit_check").orderBy("eventtime")

My questions: Is this a viable approach, and if so, how can I "go forward" and find the maximum eventtime that fulfills the 5-minute condition? As far as I know, that would mean iterating through the values of a Spark SQL Column. Is that possible, and wouldn't it be too expensive? Is there another way to achieve this result?
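
To make the rule concrete, here is a minimal plain-Python sketch (an illustration, not Spark code) of the per-row check that the lag expressions above are meant to compute, run on the sample event times. The names `within_5_min` and `LIMIT` are assumptions made for this example. A missing neighbour is treated as "not within range"; coalescing a missing difference to 0, as in the code above, would make the first and last rows pass the `< 300` test automatically.

```python
from datetime import datetime

LIMIT = 300  # seconds; the 5-minute rule from the question

times = ["03:00", "03:01", "03:02", "03:03", "03:06",
         "03:07", "03:14", "03:15", "03:34", "03:53"]
events = [datetime.strptime("2017-06-04 " + t, "%Y-%m-%d %H:%M") for t in times]

def within_5_min(i):
    """True if event i is less than LIMIT seconds from its previous or next event."""
    prev_ok = i > 0 and (events[i] - events[i - 1]).total_seconds() < LIMIT
    next_ok = (i < len(events) - 1
               and (events[i + 1] - events[i]).total_seconds() < LIMIT)
    return prev_ok or next_ok

visit_check = [int(within_5_min(i)) for i in range(len(events))]
print(visit_check)  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```

The flag separates the isolated events (03:34 and 03:53) from the rest, but 03:07 and 03:14 both get 1 even though they are 7 minutes apart. A per-row flag alone therefore cannot say where one visit ends and the next begins; that requires some kind of running group id.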

Result of the solution suggested by @Aku:

+--------+--------+---------------------+---------------------+------+
|userid  |subgroup|start_time           |end_time             |events|
+--------+--------+---------------------+---------------------+------+
|37397e29|0       |2017-06-04 03:00:00.0|2017-06-04 03:06:00.0|5     |
|37397e29|1       |2017-06-04 03:07:00.0|2017-06-04 03:14:00.0|2     |
|37397e29|2       |2017-06-04 03:15:00.0|2017-06-04 03:15:00.0|1     |
|37397e29|3       |2017-06-04 03:34:00.0|2017-06-04 03:43:00.0|2     |
+--------+--------+---------------------+---------------------+------+

It doesn't give the expected result. 3:07 - 3:14 and 03:34 - 03:43 are being counted as ranges within 5 minutes, which they shouldn't be. Also, 3:07 should be the end_time in the first row, as it is within 5 minutes of the previous row, 3:06.

ebertbm asked Aug 17 '17 13:08



2 Answers

You'll need one extra window function and a groupBy to achieve this. What we want is for every row with timeDiff greater than 300 to be the end of a group, with the following row starting a new one. Aku's solution should work, except that its indicators mark the start of a group instead of the end. To change this, take the cumulative sum up to row n-1 instead of n (n being the current row):

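from pyspark.sql.window import Window
import pyspark.sql.functions as func

w = Window.partitionBy("userid").orderBy("eventtime")
# 1 when timeDiff exceeds 5 minutes, i.e. this row ends a group
DF = DF.withColumn("indicator", (DF.timeDiff > 300).cast("int"))
# Running sum minus the current indicator = cumulative sum up to row n-1
DF = DF.withColumn("subgroup", func.sum("indicator").over(w) - func.col("indicator"))
DF = DF.groupBy("subgroup").agg(
    func.min("eventtime").alias("start_time"),
    func.max("eventtime").alias("end_time"),
    func.count("*").alias("events")
)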

+--------+-------------------+-------------------+------+
|subgroup|         start_time|           end_time|events|
+--------+-------------------+-------------------+------+
|       0|2017-06-04 03:00:00|2017-06-04 03:07:00|     6|
|       1|2017-06-04 03:14:00|2017-06-04 03:15:00|     2|
|       2|2017-06-04 03:34:00|2017-06-04 03:34:00|     1|
|       3|2017-06-04 03:53:00|2017-06-04 03:53:00|     1|
+--------+-------------------+-------------------+------+
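
The subgroup arithmetic can be checked by hand. The following plain-Python sketch (an illustration, not Spark code) replays the indicator, the running sum, and the subtraction on the `timeDiff` values from the question. The variable names are assumptions made for this example.

```python
from itertools import accumulate

# timeDiff column from the question, in eventtime order
time_diff = [60, 60, 60, 180, 60, 420, 60, 1140, 540, 540]

# 1 where timeDiff exceeds 5 minutes (the row ends a group)
indicator = [int(d > 300) for d in time_diff]

# Running sum including the current row, like sum("indicator").over(w)
running = list(accumulate(indicator))

# Subtracting the current indicator gives the sum up to row n-1,
# so a flagged row stays in the group it closes
subgroup = [r - i for r, i in zip(running, indicator)]

print(indicator)  # [0, 0, 0, 0, 0, 1, 0, 1, 1, 1]
print(running)    # [0, 0, 0, 0, 0, 1, 1, 2, 3, 4]
print(subgroup)   # [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]
```

Subgroup 0 holds six rows (03:00 to 03:07) and subgroup 1 holds two (03:14 and 03:15), which agrees with the aggregated table above.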

It seems that you also want to filter out groups with only one event, hence:

DF = DF.filter("events != 1")

+--------+-------------------+-------------------+------+
|subgroup|         start_time|           end_time|events|
+--------+-------------------+-------------------+------+
|       0|2017-06-04 03:00:00|2017-06-04 03:07:00|     6|
|       1|2017-06-04 03:14:00|2017-06-04 03:15:00|     2|
+--------+-------------------+-------------------+------+
MaFF answered Sep 22 '22 16:09



So if I understand this correctly, you essentially want to end each group when timeDiff > 300? This seems relatively straightforward with rolling window functions.

First, some imports:

from pyspark.sql.window import Window
import pyspark.sql.functions as func

Then set up the window. I assumed you would partition by userid:

w = Window.partitionBy("userid").orderBy("eventtime")

Then figure out which subgroup each observation falls into, by first marking the first member of each group and then taking a running sum of that column:

indicator = (func.col("timeDiff") > 300).cast("integer")
subgroup = func.sum(indicator).over(w).alias("subgroup")

Then apply some aggregation functions and you should be done:

DF = DF.select("*", subgroup)\
.groupBy("subgroup")\
.agg(
    func.min("eventtime").alias("start_time"),
    func.max("eventtime").alias("end_time"),
    func.count(func.lit(1)).alias("events")
)
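
As a rough check of what this inclusive running sum produces, here is a plain-Python sketch (an illustration, not Spark code) that applies the same indicator and aggregation to the sample rows from the question. The names `rows` and `summary` are assumptions made for this example, and times are shortened to HH:MM.

```python
from itertools import accumulate, groupby

# (eventtime, timeDiff) pairs from the question, already in eventtime order
rows = [("03:00", 60), ("03:01", 60), ("03:02", 60), ("03:03", 180),
        ("03:06", 60), ("03:07", 420), ("03:14", 60), ("03:15", 1140),
        ("03:34", 540), ("03:53", 540)]

indicator = [int(diff > 300) for _, diff in rows]
# Inclusive running sum: a flagged row already belongs to the next subgroup
subgroup = list(accumulate(indicator))

summary = []
for key, grp in groupby(zip(subgroup, rows), key=lambda pair: pair[0]):
    times = [row[0] for _, row in grp]
    # (subgroup, start_time, end_time, events)
    summary.append((key, min(times), max(times), len(times)))

for line in summary:
    print(line)
```

With the inclusive sum, the row whose timeDiff exceeds 300 opens a new subgroup instead of closing the current one, so 03:07 ends up grouped with 03:14. That is the off-by-one that subtracting the indicator, as in the other answer, is meant to correct.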
aku answered Sep 24 '22 16:09
