Cumulate arrays from earlier rows (PySpark dataframe)

A (Python) example will make my question clear. Let's say I have a Spark dataframe of people who watched certain movies on certain dates, as follows:

movierecord = spark.createDataFrame([("Alice", 1, ["Avatar"]),("Bob", 2, ["Fargo", "Tron"]),("Alice", 4, ["Babe"]), ("Alice", 6, ["Avatar", "Airplane"]), ("Alice", 7, ["Pulp Fiction"]), ("Bob", 9, ["Star Wars"])],["name","unixdate","movies"])

The schema and the dataframe defined by the above look as follows:

root
 |-- name: string (nullable = true)
 |-- unixdate: long (nullable = true)
 |-- movies: array (nullable = true)
 |    |-- element: string (containsNull = true)

+-----+--------+------------------+
|name |unixdate|movies            |
+-----+--------+------------------+
|Alice|1       |[Avatar]          |
|Bob  |2       |[Fargo, Tron]     |
|Alice|4       |[Babe]            |
|Alice|6       |[Avatar, Airplane]|
|Alice|7       |[Pulp Fiction]    |
|Bob  |9       |[Star Wars]       |
+-----+--------+------------------+

Starting from the above, I'd like to generate a new dataframe column that holds all movies previously seen by each user, without duplicates ("previous" being determined by the unixdate field). It should look like this:

+-----+--------+------------------+------------------------+
|name |unixdate|movies            |previous_movies         |
+-----+--------+------------------+------------------------+
|Alice|1       |[Avatar]          |[]                      |
|Bob  |2       |[Fargo, Tron]     |[]                      |
|Alice|4       |[Babe]            |[Avatar]                |
|Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
|Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
|Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
+-----+--------+------------------+------------------------+

How do I implement this in a nice efficient way?

Asked Jan 05 '17 by xenocyon


1 Answer

SQL only, without preserving the order of objects:

  • Required imports:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    
  • Window definition:

    w = Window.partitionBy("name").orderBy("unixdate")
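    # With an orderBy and no explicit frame, this window runs from the start of
    # the partition up to the current row, which is what makes the collect_set /
    # collect_list aggregates below cumulative.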
    
  • Complete solution:

    (movierecord
        # Flatten movies
        .withColumn("previous_movie", f.explode("movies"))
        # Collect unique
        .withColumn("previous_movies", f.collect_set("previous_movie").over(w))
        # Drop duplicates for a single unixdate
        .groupBy("name", "unixdate")
        .agg(f.max(f.struct(
            f.size("previous_movies"),
            f.col("movies").alias("movies"),
            f.col("previous_movies").alias("previous_movies")
        )).alias("tmp"))
        # Shift by one and extract
        .select(
            "name", "unixdate", "tmp.movies",
            f.lag("tmp.previous_movies", 1).over(w).alias("previous_movies")))
    
  • The result:

     +-----+--------+------------------+------------------------+
     |name |unixdate|movies            |previous_movies         |
     +-----+--------+------------------+------------------------+
     |Bob  |2       |[Fargo, Tron]     |null                    |
     |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
     |Alice|1       |[Avatar]          |null                    |
     |Alice|4       |[Babe]            |[Avatar]                |
     |Alice|6       |[Avatar, Airplane]|[Babe, Avatar]          |
     |Alice|7       |[Pulp Fiction]    |[Babe, Airplane, Avatar]|
     +-----+--------+------------------+------------------------+
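
  • Optional cleanup (an addition, not part of the original answer): lag()
    yields null rather than an empty array for each user's first row. If you
    want empty arrays to match the requested output, you can coalesce the
    column with an empty string array (cumulated below is just a placeholder
    name for the DataFrame produced by the complete solution above):

    # Hypothetical follow-up step: replace nulls from lag() with empty arrays
    # and sort for presentation.
    (cumulated
        .withColumn("previous_movies",
                    f.coalesce("previous_movies", f.array().cast("array<string>")))
        .orderBy("unixdate")
        .show(truncate=False))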
    

SQL and a Python UDF, preserving the order:

  • Imports:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    from pyspark.sql import Column
    from pyspark.sql.types import ArrayType, StringType
    
    from typing import List, Union
    
    # https://github.com/pytoolz/toolz
    from toolz import unique, concat, compose
    
  • UDF:

    def flatten_distinct(col: Union[Column, str]) -> Column:
        def flatten_distinct_(xss: Union[List[List[str]], None]) -> List[str]:
            return compose(list, unique, concat)(xss or [])
        return f.udf(flatten_distinct_, ArrayType(StringType()))(col)
    
  • Window definition as before.

  • Complete solution:

    (movierecord
        # Collect lists
        .withColumn("previous_movies", f.collect_list("movies").over(w))
        # Flatten and drop duplicates
        .withColumn("previous_movies", flatten_distinct("previous_movies"))
        # Shift by one
        .withColumn("previous_movies", f.lag("previous_movies", 1).over(w))
        # For presentation only
        .orderBy("unixdate")) 
    
  • The result:

    +-----+--------+------------------+------------------------+
    |name |unixdate|movies            |previous_movies         |
    +-----+--------+------------------+------------------------+
    |Alice|1       |[Avatar]          |null                    |
    |Bob  |2       |[Fargo, Tron]     |null                    |
    |Alice|4       |[Babe]            |[Avatar]                |
    |Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
    |Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
    |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
    +-----+--------+------------------+------------------------+
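
  • A built-in alternative on Spark 2.4+ (an addition, not part of the original
    answer): flatten and array_distinct can stand in for the Python UDF while
    still preserving first-occurrence order. A minimal sketch, reusing the same
    window w:

    # flatten() concatenates the collected arrays and array_distinct() drops
    # duplicates while keeping first occurrences, so no Python UDF is needed.
    (movierecord
        .withColumn("previous_movies",
                    f.array_distinct(f.flatten(f.collect_list("movies").over(w))))
        .withColumn("previous_movies", f.lag("previous_movies", 1).over(w))
        .orderBy("unixdate"))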
    

Performance:

I believe there is no efficient way to solve this given the constraints. Not only does the requested output require significant data duplication (data is binary encoded to fit the Tungsten format, so you get possible compression but lose object identity), it also needs a number of operations that are expensive under Spark's computing model, including grouping and sorting.

This should be fine if the expected size of previous_movies is bounded and small, but it won't be feasible in general.

Data duplication is fairly easy to address by keeping a single, lazy history per user. That is not something that can be done in SQL, but it is quite easy with low-level RDD operations.
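
As an illustration only (a rough sketch with made-up helper names, not the author's code), the per-user traversal could look roughly like this with the RDD API:

    # Walk each user's rows in date order and carry one running history forward.
    # This still copies the history into every output row, so it shows the
    # traversal idea rather than a fully lazy solution.
    def with_history(rows):
        seen, out = [], []
        for r in sorted(rows, key=lambda r: r.unixdate):
            out.append((r.name, r.unixdate, r.movies, list(seen)))
            for m in r.movies:
                if m not in seen:
                    seen.append(m)
        return out

    previous = (movierecord.rdd
        .groupBy(lambda r: r.name)                # one group per user
        .flatMap(lambda kv: with_history(kv[1]))  # rows carrying the running history
        .toDF("name string, unixdate long, "
              "movies array<string>, previous_movies array<string>"))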

The explode and collect_* pattern is expensive. If your requirements are strict but you want to improve performance, you can use a Scala UDF in place of the Python one.

Answered by zero323