 

How to do opposite of explode in PySpark?

Let's say I have a DataFrame with a column for users and another column for words they've written:

Row(user='Bob', word='hello')
Row(user='Bob', word='world')
Row(user='Mary', word='Have')
Row(user='Mary', word='a')
Row(user='Mary', word='nice')
Row(user='Mary', word='day')

I would like to aggregate the word column into a vector:

Row(user='Bob', words=['hello','world'])
Row(user='Mary', words=['Have','a','nice','day'])

It seems I can't use any of Spark's grouping functions, because they expect a subsequent aggregation step. My use case is that I want to feed these data into Word2Vec, not apply other Spark aggregations.

asked Apr 11 '17 by Evan Zamir

People also ask

What is opposite of explode in Spark?

The requirement is to reverse the explode operation, i.e. to convert the exploded values back into an array column on a Spark DataFrame.

How do you flatten an array in PySpark?

If you want to flatten nested arrays, use the flatten function, which converts an array-of-arrays column into a single array column on a DataFrame.
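
For instance, a minimal sketch (assuming an existing SparkSession named spark; the column names are only illustrative):

from pyspark.sql import functions as F

# A column holding an array of arrays
nested_df = spark.createDataFrame([(1, [[1, 2], [3, 4]])], ['id', 'nested'])

# flatten collapses the nested arrays into a single array per row
print(nested_df.select(F.flatten('nested').alias('flat')).collect())
# [Row(flat=[1, 2, 3, 4])]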

What does PySpark explode do?

Returns a new row for each element in the given array or map. Uses the default column name col for elements in the array and key and value for elements in the map unless specified otherwise.
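
As an illustration (again assuming a SparkSession named spark), explode turns one row per array into one row per element:

from pyspark.sql import functions as F

words_df = spark.createDataFrame([('Bob', ['hello', 'world'])], ['user', 'words'])

# Each element of the words array becomes its own row
print(words_df.select('user', F.explode('words').alias('word')).collect())
# [Row(user='Bob', word='hello'), Row(user='Bob', word='world')]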

Is Spark explode expensive?

In one benchmark, the alternative approach averaged a run time of 0.22 s, around 8x faster than using explode. For those skimming this post, a short summary: explode is an expensive operation, and you can often find a more performance-oriented solution (it might not be as easy to write, but it will run faster) instead of this standard Spark method.
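
One hedged illustration of that advice: when the end result can be computed directly on the array column, the explode/groupBy round trip can often be skipped entirely. The DataFrame and column names below are assumed for the example:

from pyspark.sql import functions as F

words_df = spark.createDataFrame([('Bob', ['hello', 'world']),
                                  ('Mary', ['Have', 'a', 'nice', 'day'])],
                                 ['user', 'words'])

# Expensive pattern: explode, then group back up to count words per user
exploded_counts = (words_df.select('user', F.explode('words').alias('word'))
                           .groupBy('user').count())

# Cheaper pattern: operate on the array column directly, no shuffle needed
direct_counts = words_df.select('user', F.size('words').alias('count'))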


1 Answer

Thanks to @titipat for giving the RDD solution. I realized shortly after my post that there is actually a DataFrame solution using collect_set (or collect_list):

from pyspark.sql import Row
from pyspark.sql.functions import collect_set

# Build the example DataFrame of (user, word) pairs
rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                      Row(user='Bob', word='world'),
                                      Row(user='Mary', word='Have'),
                                      Row(user='Mary', word='a'),
                                      Row(user='Mary', word='nice'),
                                      Row(user='Mary', word='day')])
df = spark.createDataFrame(rdd)

# Group by user and collect the word column back into an array
group_user = df.groupBy('user').agg(collect_set('word').alias('words'))
print(group_user.collect())

>[Row(user='Mary', words=['Have', 'nice', 'day', 'a']), Row(user='Bob', words=['world', 'hello'])]
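
One caveat worth noting: collect_set de-duplicates words and does not guarantee order, as the output above shows. If duplicates and word order matter for the downstream Word2Vec step, collect_list is usually the better choice. A minimal sketch of wiring the result into pyspark.ml's Word2Vec follows (the vectorSize and minCount values are only illustrative):

from pyspark.sql.functions import collect_list
from pyspark.ml.feature import Word2Vec

# collect_list keeps duplicate words (ordering after a shuffle is still not strictly guaranteed)
group_user = df.groupBy('user').agg(collect_list('word').alias('words'))

# Train Word2Vec directly on the aggregated words column
word2vec = Word2Vec(vectorSize=50, minCount=1, inputCol='words', outputCol='vectors')
model = word2vec.fit(group_user)
print(model.getVectors().collect())
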
answered Sep 21 '22 by Evan Zamir