As the subject describes, I have a PySpark DataFrame in which I need to melt three columns into rows. Each column essentially represents a single fact in a category. The ultimate goal is to aggregate the data into a single total per category.
There are tens of millions of rows in this DataFrame, so I need a way to do the transformation on the Spark cluster without bringing any data back to the driver (Jupyter in this case).
Here is an extract of my dataframe for just a few stores:
+-----------+----------------+-----------------+----------------+
|   store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+-----------+----------------+-----------------+----------------+
|        100|              30|              105|              35|
|        200|              55|               85|              65|
|        300|              20|              125|              90|
+-----------+----------------+-----------------+----------------+
Here is the desired resulting dataframe, multiple rows per store, where the columns of the original dataframe have been melted into rows of the new dataframe, with one row per original column in a new category column:
+-----------+--------+-----------+
|   store_id|CATEGORY|qty_on_hand|
+-----------+--------+-----------+
|        100|    milk|         30|
|        100|   bread|        105|
|        100|    eggs|         35|
|        200|    milk|         55|
|        200|   bread|         85|
|        200|    eggs|         65|
|        300|    milk|         20|
|        300|   bread|        125|
|        300|    eggs|         90|
+-----------+--------+-----------+
Ultimately, I want to aggregate the resulting dataframe to get the totals per category:
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    milk|              105|
|   bread|              315|
|    eggs|              190|
+--------+-----------------+
UPDATE: There is a suggestion that this question is a duplicate and can be answered here. This is not the case, as the solution casts rows to columns and I need to do the reverse, melt columns into rows.
Spark's pivot() rotates the values of one DataFrame column into multiple columns (rows to columns). It is an aggregation in which the distinct values of one grouping column become individual columns. Unpivot is the reverse operation and turns columns back into rows.
Melt is another name for unpivot: the DataFrame is converted from wide to long format. Spark 3.4 and later provide DataFrame.unpivot() (with melt() as an alias) for this, and there are likely several other ways to implement a melt in PySpark.
We can use the explode() function to solve this. In pandas, the same reshaping is done with melt().
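For readers who know the pandas version, here is a minimal sketch of melt on the sample data from the question. It runs on a single machine, so it is only a reference for the expected shape and totals, not a substitute for the Spark code that follows. The variable names pdf, long_pdf and totals are illustrative and do not appear in the original post.

```python
import pandas as pd

# Same sample data as the question, held locally in pandas (not on the cluster).
pdf = pd.DataFrame({
    "store_id": [100, 200, 300],
    "qty_on_hand_milk": [30, 55, 20],
    "qty_on_hand_bread": [105, 85, 125],
    "qty_on_hand_eggs": [35, 65, 90],
})

# Wide to long: one row per (store_id, original column).
long_pdf = pdf.melt(id_vars="store_id", var_name="CATEGORY", value_name="qty_on_hand")

# Strip the common prefix so CATEGORY holds milk / bread / eggs.
long_pdf["CATEGORY"] = long_pdf["CATEGORY"].str.replace("qty_on_hand_", "", regex=False)

# Total per category.
totals = long_pdf.groupby("CATEGORY")["qty_on_hand"].sum()
print(totals)
```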
# Load the required functions (note: this sum shadows Python's built-in sum)
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit
# Create the DataFrame (sqlContext is assumed to be predefined by the notebook session)
df = sqlContext.createDataFrame(
    [(100, 30, 105, 35), (200, 55, 85, 65), (300, 20, 125, 90)],
    ('store_id', 'qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs'))
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+
The function below explodes this DataFrame:
def to_explode(df, by):
    # Split the non-key columns into names and type descriptions
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # Spark arrays must be homogeneous, so all melted columns need the same type
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"
    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
        struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
    ])).alias("kvs")
    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])
Applying the function to this DataFrame to explode it:
df = to_explode(df, ['store_id'])\
    .drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+
Now we need to remove the prefix qty_on_hand_ from the CATEGORY column. This can be done with the expr() function. Note that Spark SQL's substring uses 1-based indexing rather than 0-based, so the text after the 12-character prefix starts at position 13:
df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+
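The offset 13 in the substring call is easy to get wrong by one, so here is a small plain-Python sketch of where it comes from. The names PREFIX, start and strip_prefix_expr are illustrative, not part of the original answer. The resulting string is what would be handed to expr().

```python
# Prefix shared by the melted column names (taken from the sample data).
PREFIX = "qty_on_hand_"

# Spark SQL's substring(str, pos) counts positions from 1, so the first
# character after the prefix sits at position len(PREFIX) + 1.
start = len(PREFIX) + 1

# Expression string that could be passed to pyspark.sql.functions.expr().
strip_prefix_expr = f"substring(CATEGORY, {start})"
print(strip_prefix_expr)  # substring(CATEGORY, 13)

# The same slice in plain Python, which is 0-based, for comparison.
print("qty_on_hand_bread"[len(PREFIX):])  # bread
```

Deriving the offset from the prefix length keeps the expression correct if the column naming scheme changes.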
Finally, aggregate the qty_on_hand column, grouped by CATEGORY, using the agg() function:
df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+
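As a possible alternative to the explode approach, Spark SQL's stack() generator can unpivot the columns and assign the short category labels in one step. The sketch below only builds the expression string in plain Python. build_stack_expr is a hypothetical helper written for this illustration, not a PySpark function, and its key and value names are assumptions matching the tables above.

```python
def build_stack_expr(columns, prefix, key_name="CATEGORY", value_name="qty_on_hand"):
    """Build a Spark SQL stack() expression that unpivots `columns`.

    Hypothetical helper: each column becomes one ('label', column) pair,
    where the label is the column name with `prefix` removed.
    """
    pairs = []
    for c in columns:
        label = c[len(prefix):] if c.startswith(prefix) else c
        pairs.append(f"'{label}', `{c}`")
    return f"stack({len(columns)}, {', '.join(pairs)}) as ({key_name}, {value_name})"


cols = ["qty_on_hand_milk", "qty_on_hand_bread", "qty_on_hand_eggs"]
stack_expr = build_stack_expr(cols, prefix="qty_on_hand_")
print(stack_expr)
```

Assuming df still refers to the original wide DataFrame (before it is reassigned above), df.selectExpr("store_id", stack_expr).groupBy("CATEGORY").sum("qty_on_hand") should give the per-category totals, under the default column name sum(qty_on_hand). As with the explode version, the stacked columns need to share one type.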