A (Python) example will make my question clear. Let's say I have a Spark dataframe of people who watched certain movies on certain dates, as follows:
movierecord = spark.createDataFrame(
    [("Alice", 1, ["Avatar"]),
     ("Bob", 2, ["Fargo", "Tron"]),
     ("Alice", 4, ["Babe"]),
     ("Alice", 6, ["Avatar", "Airplane"]),
     ("Alice", 7, ["Pulp Fiction"]),
     ("Bob", 9, ["Star Wars"])],
    ["name", "unixdate", "movies"])
The schema and the dataframe defined by the above look as follows:
root
|-- name: string (nullable = true)
|-- unixdate: long (nullable = true)
|-- movies: array (nullable = true)
| |-- element: string (containsNull = true)
+-----+--------+------------------+
|name |unixdate|movies |
+-----+--------+------------------+
|Alice|1 |[Avatar] |
|Bob |2 |[Fargo, Tron] |
|Alice|4 |[Babe] |
|Alice|6 |[Avatar, Airplane]|
|Alice|7 |[Pulp Fiction] |
|Bob |9 |[Star Wars] |
+-----+--------+------------------+
Starting from the above, I'd like to generate a new dataframe column which holds all previous movies seen by each user, without duplicates ("previous" per the unixdate field). So it should look like this:
+-----+--------+------------------+------------------------+
|name |unixdate|movies |previous_movies |
+-----+--------+------------------+------------------------+
|Alice|1 |[Avatar] |[] |
|Bob |2 |[Fargo, Tron] |[] |
|Alice|4 |[Babe] |[Avatar] |
|Alice|6 |[Avatar, Airplane]|[Avatar, Babe] |
|Alice|7 |[Pulp Fiction] |[Avatar, Babe, Airplane]|
|Bob |9 |[Star Wars] |[Fargo, Tron] |
+-----+--------+------------------+------------------------+
How do I implement this in a nice efficient way?
SQL only, without preserving the order of objects:
Required imports:
import pyspark.sql.functions as f
from pyspark.sql.window import Window
Window definition:
w = Window.partitionBy("name").orderBy("unixdate")
Complete solution:
(movierecord
    # Flatten movies
    .withColumn("previous_movie", f.explode("movies"))
    # Collect unique
    .withColumn("previous_movies", f.collect_set("previous_movie").over(w))
    # Drop duplicates for a single unixdate
    .groupBy("name", "unixdate")
    .agg(f.max(f.struct(
        f.size("previous_movies"),
        f.col("movies").alias("movies"),
        f.col("previous_movies").alias("previous_movies")
    )).alias("tmp"))
    # Shift by one and extract
    .select(
        "name", "unixdate", "tmp.movies",
        f.lag("tmp.previous_movies", 1).over(w).alias("previous_movies")))
The result:
+-----+--------+------------------+------------------------+
|name |unixdate|movies |previous_movies |
+-----+--------+------------------+------------------------+
|Bob |2 |[Fargo, Tron] |null |
|Bob |9 |[Star Wars] |[Fargo, Tron] |
|Alice|1 |[Avatar] |null |
|Alice|4 |[Babe] |[Avatar] |
|Alice|6 |[Avatar, Airplane]|[Babe, Avatar] |
|Alice|7 |[Pulp Fiction] |[Babe, Airplane, Avatar]|
+-----+--------+------------------+------------------------+
SQL and Python UDF, preserving the order:
Imports:
import pyspark.sql.functions as f
from pyspark.sql.window import Window
from pyspark.sql import Column
from pyspark.sql.types import ArrayType, StringType
from typing import List, Union
# https://github.com/pytoolz/toolz
from toolz import unique, concat, compose
UDF:
def flatten_distinct(col: Union[Column, str]) -> Column:
    def flatten_distinct_(xss: Union[List[List[str]], None]) -> List[str]:
        return compose(list, unique, concat)(xss or [])
    return f.udf(flatten_distinct_, ArrayType(StringType()))(col)
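If you prefer to avoid the toolz dependency, an equivalent plain-Python version of the UDF (a sketch that likewise keeps first-seen order) could look like this:

def flatten_distinct(col: Union[Column, str]) -> Column:
    def flatten_distinct_(xss: Union[List[List[str]], None]) -> List[str]:
        seen = set()
        result = []
        # Flatten the list of lists, keeping only the first occurrence of each movie
        for xs in (xss or []):
            for x in xs:
                if x not in seen:
                    seen.add(x)
                    result.append(x)
        return result
    return f.udf(flatten_distinct_, ArrayType(StringType()))(col)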
Window definition as before.
Complete solution:
(movierecord
    # Collect lists
    .withColumn("previous_movies", f.collect_list("movies").over(w))
    # Flatten and drop duplicates
    .withColumn("previous_movies", flatten_distinct("previous_movies"))
    # Shift by one
    .withColumn("previous_movies", f.lag("previous_movies", 1).over(w))
    # For presentation only
    .orderBy("unixdate"))
The result:
+-----+--------+------------------+------------------------+
|name |unixdate|movies |previous_movies |
+-----+--------+------------------+------------------------+
|Alice|1 |[Avatar] |null |
|Bob |2 |[Fargo, Tron] |null |
|Alice|4 |[Babe] |[Avatar] |
|Alice|6 |[Avatar, Airplane]|[Avatar, Babe] |
|Alice|7 |[Pulp Fiction] |[Avatar, Babe, Airplane]|
|Bob |9 |[Star Wars] |[Fargo, Tron] |
+-----+--------+------------------+------------------------+
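One detail worth noting: both variants leave previous_movies as null on a user's first row, whereas the requested output shows an empty array. If you need the empty-array form, one option (a minimal sketch; the name result is hypothetical and stands for the output of either pipeline above) is to coalesce with an empty array:

import pyspark.sql.functions as f

# Replace null with an empty array<string>; `result` is assumed to be the
# dataframe produced by either of the solutions above.
result = result.withColumn(
    "previous_movies",
    f.coalesce(f.col("previous_movies"), f.array().cast("array<string>")))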
Performance:
I believe there is no efficient way to solve this given the constraints. Not only does the requested output require significant data duplication (the data is binary encoded to fit the Tungsten format, so you get possible compression but lose object identity), it also requires a number of operations that are expensive under the Spark computing model, including grouping and sorting.
This should be fine if the expected size of previous_movies is bounded and small, but it won't be feasible in general.
Data duplication is fairly easy to address by keeping a single, lazy history per user. That is not something that can be done in SQL, but it is quite easy with low-level RDD operations, as sketched below.
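A minimal sketch of that RDD idea (an illustration under assumptions, not the SQL solutions above: it uses groupByKey, so it assumes a single user's full history fits comfortably on one executor):

from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType

def with_history(rows):
    # For one user's (unixdate, movies) pairs, yield (unixdate, movies, previous_movies),
    # with previous_movies accumulated incrementally in first-seen order.
    seen, seen_set = [], set()
    for unixdate, movies in sorted(rows, key=lambda r: r[0]):
        yield unixdate, movies, list(seen)
        for m in movies:
            if m not in seen_set:
                seen_set.add(m)
                seen.append(m)

history_rdd = (movierecord.rdd
    .map(lambda row: (row.name, (row.unixdate, row.movies)))
    .groupByKey()  # gathers all of one user's rows together
    .flatMap(lambda kv: [(kv[0], d, m, prev) for d, m, prev in with_history(kv[1])]))

schema = StructType([
    StructField("name", StringType()),
    StructField("unixdate", LongType()),
    StructField("movies", ArrayType(StringType())),
    StructField("previous_movies", ArrayType(StringType())),
])

history_df = spark.createDataFrame(history_rdd, schema)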
The explode and collect_* pattern is expensive. If your requirements are strict but you want to improve performance, you can use a Scala UDF in place of the Python one.