
Iterating over PySpark GroupedData

Let's assume the original data looks like this:

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp1       B       £11       £16
Comp1       C       £11       £15
Comp2       A       £9        £16
Comp2       B       £12       £14
Comp2       C       £14       £17
Comp3       A       £11       £16
Comp3       B       £10       £15
Comp3       C       £12       £15

(Ref: Python - splitting dataframe into multiple dataframes based on column values and naming them with those values)

I wish to get a list of sub-dataframes split on a column's values, say Region, like:

df_A :

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp2       A       £9        £16
Comp3       A       £11       £16

With a pandas DataFrame I could do:

for region, df_region in df.groupby('Region'):
    print(df_region)

Can I do the same iteration if df is a PySpark DataFrame?

In PySpark, once I do df.groupBy("Region") I get a GroupedData object. I don't need any aggregation like count, mean, etc. I just need a list of sub-dataframes, each sharing the same "Region" value. Is that possible?

asked Jul 23 '18 by Yogesh Kulkarni


People also ask

How do you iterate over rows in PySpark?

To loop through each row with map(), first convert the PySpark DataFrame into an RDD, since map() is only defined on RDDs. Then apply map() with a lambda function that processes each row, store the resulting RDD, and convert it back into a DataFrame.
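A minimal sketch of that pattern (assuming an active SparkSession named spark and the df from the question; the selected columns are just for illustration):

# Convert the DataFrame to an RDD of Row objects
rdd = df.rdd

# Apply a lambda to each row, e.g. keep the Competitor and Region values
pairs_rdd = rdd.map(lambda row: (row["Competitor"], row["Region"]))

# Convert the transformed RDD back into a DataFrame
result_df = pairs_rdd.toDF(["Competitor", "Region"])
result_df.show()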

How do you use foreach() in PySpark?

First create a DataFrame, then define a simple function that prints a row, and pass that function to foreach(). foreach() applies the function to every row of the DataFrame, which effectively iterates over all the elements.
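A small sketch of that (assuming an active SparkSession named spark; the sample data is made up for illustration):

# Illustrative DataFrame
df2 = spark.createDataFrame([("Comp1", "A"), ("Comp2", "B")], ["Competitor", "Region"])

def print_row(row):
    # Runs on the executors, so the output shows up in executor logs, not the driver console
    print(row)

# Apply the function to every row
df2.foreach(print_row)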

What does collect() do in PySpark?

collect() is an action on an RDD or DataFrame that retrieves the data to the driver. It gathers all the rows from every partition and brings them back to the driver node/program, so it should only be used when the result fits in driver memory.
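For example (reusing the df from the question; the column names are assumptions for illustration):

# Pull every Row back to the driver and loop over them locally
rows = df.collect()
for row in rows:
    print(row["Competitor"], row["Region"])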


1 Answer

The approach below should work for you, under the assumption that the list of unique values in the grouping column is small enough to fit in memory on the driver. Hope this helps!

import pyspark.sql.functions as F
import pandas as pd

# Sample data 
df = pd.DataFrame({'region': ['aa','aa','aa','bb','bb','cc'],
                   'x2': [6,5,4,3,2,1],
                   'x3': [1,2,3,4,5,6]})
df = spark.createDataFrame(df)

# Get unique values in the grouping column
groups = [x[0] for x in df.select("region").distinct().collect()]

# Create a filtered DataFrame for each group in a list comprehension
groups_list = [df.filter(F.col('region')==x) for x in groups]

# show the results
[x.show() for x in groups_list]

Result:

+------+---+---+
|region| x2| x3|
+------+---+---+
|    cc|  1|  6|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    bb|  3|  4|
|    bb|  2|  5|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    aa|  6|  1|
|    aa|  5|  2|
|    aa|  4|  3|
+------+---+---+
answered Oct 12 '22 by Florian