
Aggregate while dropping duplicates in pyspark

I want to group by and aggregate a pyspark dataframe while removing duplicates (keeping the last value) based on another column of this dataframe.

In summary, I would like to apply dropDuplicates to a GroupedData object, so that for each group I keep only one row, chosen dynamically by some column.

Example

The straightforward group aggregation, for the dataframe below, would be:

from pyspark.sql import functions

dataframe = spark.createDataFrame(
    [
        (1, "2020-01-01", 1, 1),
        (2, "2020-01-01", 2, 1),
        (3, "2020-01-02", 1, 1),
        (2, "2020-01-02", 1, 1)
    ],
    ("id", "ts", "feature", "h3")
).withColumn("ts", functions.col("ts").cast("timestamp"))

# +---+-------------------+-------+---+
# | id|                 ts|feature| h3|
# +---+-------------------+-------+---+
# |  1|2020-01-01 00:00:00|      1|  1|
# |  2|2020-01-01 00:00:00|      2|  1|
# |  3|2020-01-02 00:00:00|      1|  1|
# |  2|2020-01-02 00:00:00|      1|  1|
# +---+-------------------+-------+---+

aggregated = dataframe.groupby("h3",
  functions.window(
    timeColumn="ts",
    windowDuration="3 days",
    slideDuration="1 day",
  )
).agg(
  functions.sum("feature")
)
aggregated.show(truncate=False)

resulting in the following dataframe:

+---+------------------------------------------+------------+
|h3 |window                                    |sum(feature)|
+---+------------------------------------------+------------+
|1  |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3           |
|1  |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|5           |
|1  |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|5           |
|1  |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|2           |
+---+------------------------------------------+------------+

The problem

I want the aggregation to use only the latest state of each id. In this case, id=2 has been updated to feature=1 at ts=2020-01-02 00:00:00, so all aggregations whose windows cover timestamps at or after 2020-01-02 00:00:00 should use only this state for column feature when id=2. The expected aggregated dataframe is:

+---+------------------------------------------+------------+
|h3 |window                                    |sum(feature)|
+---+------------------------------------------+------------+
|1  |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3           |
|1  |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|3           |
|1  |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|3           |
|1  |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|2           |
+---+------------------------------------------+------------+
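
For context: without the time windows, keeping only the latest row per id is easy with a window function. A minimal sketch of what I mean by "latest state", using the dataframe defined above:

from pyspark.sql import Window

w = Window.partitionBy("id").orderBy(functions.col("ts").desc())

latest_per_id = (
    dataframe
    .withColumn("rn", functions.row_number().over(w))   # rn = 1 marks the most recent ts per id
    .filter(functions.col("rn") == 1)
    .drop("rn")
)
latest_per_id.show()

But here "latest" has to be evaluated inside each sliding window, which is what I can't express.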

How can I do this with pyspark?

Update

I had assumed that a MapType column cannot have duplicate keys in Spark. Under that assumption, I thought I could build a map id -> feature per group and then just aggregate the map values with sum (or whatever the final aggregation should be).

So I did:

aggregated = dataframe.groupby("h3",
  functions.window(
    timeColumn="ts",
    windowDuration="3 days",
    slideDuration="1 day",
  )
).agg(
  functions.map_from_entries(
    functions.collect_list(
      functions.struct("id","feature")
    )
  ).alias("id_feature")
)
aggregated.show(truncate=False)

But then I found that maps can have duplicate keys:

+---+------------------------------------------+--------------------------------+
|h3 |window                                    |id_feature                      |
+---+------------------------------------------+--------------------------------+
|1  |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|[1 -> 1, 2 -> 2, 3 -> 1, 2 -> 1]|
|1  |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|[1 -> 1, 2 -> 2, 3 -> 1, 2 -> 1]|
|1  |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|[1 -> 1, 2 -> 2]                |
|1  |[2020-01-02 00:00:00, 2020-01-05 00:00:00]|[3 -> 1, 2 -> 1]                |
+---+------------------------------------------+--------------------------------+

so it doesn't solve my problem; instead, it revealed another one: when using the display function in a Databricks notebook, the MapType column is shown without the duplicate keys.
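
To confirm the duplicate keys are really in the data and not just a rendering artifact, the keys can be inspected directly, for example (for the wider windows above, this lists the key 2 twice):

aggregated.select(
    "window",
    functions.map_keys("id_feature").alias("keys")
).show(truncate=False)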

Igor Hoelscher asked Mar 09 '20


2 Answers

First, you can find the latest record for each id and time window, and then join the original dataframe with those latest records.

from pyspark.sql.functions import col, max, sum, window  # note: max/sum here are the Spark functions

df = dataframe  # the dataframe from the question

# latest timestamp per (h3, window, id)
time_window = window(timeColumn="ts", windowDuration="3 days", slideDuration="1 day")

df2 = df.groupBy("h3", time_window, "id").agg(max("ts").alias("latest"))

# join back to pick up the feature value at each id's latest timestamp,
# then aggregate once per (h3, window)
df2.alias("a").join(df.alias("b"), (col("a.id") == col("b.id")) & (col("a.latest") == col("b.ts")), "left") \
   .select("a.*", "feature") \
   .groupBy("h3", "window") \
   .agg(sum("feature")) \
   .orderBy("window") \
   .show(truncate=False)

Then the result is the same as your expected one (the window boundaries here are shifted by a day relative to the question's output, most likely a session time zone difference, but the per-window sums match):

+---+------------------------------------------+------------+
|h3 |window                                    |sum(feature)|
+---+------------------------------------------+------------+
|1  |[2019-12-29 00:00:00, 2020-01-01 00:00:00]|3           |
|1  |[2019-12-30 00:00:00, 2020-01-02 00:00:00]|3           |
|1  |[2019-12-31 00:00:00, 2020-01-03 00:00:00]|3           |
|1  |[2020-01-01 00:00:00, 2020-01-04 00:00:00]|2           |
+---+------------------------------------------+------------+
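
The same idea can also be written without the join: max over a struct compares fields left to right, so max(struct(ts, feature)) per (h3, window, id) picks each id's latest feature directly. A minimal sketch, assuming the question's dataframe and imports:

from pyspark.sql import functions

latest = dataframe.groupBy(
    "h3",
    functions.window(timeColumn="ts", windowDuration="3 days", slideDuration="1 day"),
    "id",
).agg(
    # struct comparison goes field by field, so the max carries the largest ts
    functions.max(functions.struct("ts", "feature")).alias("latest")
)

latest.groupBy("h3", "window") \
      .agg(functions.sum("latest.feature").alias("sum(feature)")) \
      .orderBy("window") \
      .show(truncate=False)
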
Lamanus answered Oct 09 '22


Since you are using Spark 2.4+, one way you can try is to use the Spark SQL aggregate function; see below:

aggregated = dataframe.groupby("h3",
   functions.window( 
     timeColumn="ts", 
     windowDuration="3 days", 
     slideDuration="1 day", 
   ) 
 ).agg( 
     functions.sort_array(functions.collect_list( 
       functions.struct("ts", "id", "feature") 
     ), False).alias("id_feature") 
 )   

I added the ts field to the array of structs produced by functions.collect_list, and used functions.sort_array to sort the list by ts in descending order (so the latest record comes first when duplicates exist). In the following aggregate function, we set the zero_value to a named_struct with two fields: ids (a MapType) to cache all processed ids, and total, which is incremented only when a new id does not yet exist in the cached ids.

aggregated.selectExpr("h3", "window", """
  aggregate(
    id_feature,
    /* zero_value */
    (map() as ids, 0L as total), 
    /* merge */
    (acc, y) -> named_struct(
      /* add y.id into the ids map */
      'ids', map_concat(acc.ids, map(y.id,1)), 
      /* sum to total only when y.id doesn't exist in acc.ids map */
      'total', acc.total + IF(acc.ids[y.id] is null,y.feature,0)
    ), 
    /* finish, take only acc.total, discard acc.ids map */
    acc -> acc.total
  ) as id_feature

""").show()
+---+--------------------+----------+
| h3|              window|id_feature|
+---+--------------------+----------+
|  1|[2020-01-01 00:00...|         3|
|  1|[2019-12-31 00:00...|         3|
|  1|[2019-12-30 00:00...|         3|
|  1|[2020-01-02 00:00...|         2|
+---+--------------------+----------+
jxc answered Oct 09 '22