PySpark Dataframe melt columns into rows

As the title says, I have a PySpark DataFrame in which I need to melt three columns into rows. Each column represents a single fact in a category. The ultimate goal is to aggregate the data into a single total per category.

There are tens of millions of rows in this DataFrame, so I need a way to do the transformation on the Spark cluster without bringing any data back to the driver (Jupyter in this case).

Here is an extract of my dataframe for just a few stores:

+-----------+----------------+-----------------+----------------+
|   store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+-----------+----------------+-----------------+----------------+
|        100|              30|              105|              35|
|        200|              55|               85|              65|
|        300|              20|              125|              90|
+-----------+----------------+-----------------+----------------+

Here is the desired resulting dataframe, with multiple rows per store. The columns of the original dataframe have been melted into rows of the new dataframe, with one row per original column in a new category column:

+-----------+--------+-----------+
| product_id|CATEGORY|qty_on_hand|
+-----------+--------+-----------+
|        100|    milk|         30|
|        100|   bread|        105|
|        100|    eggs|         35|
|        200|    milk|         55|
|        200|   bread|         85|
|        200|    eggs|         65|
|        300|    milk|         20|
|        300|   bread|        125|
|        300|    eggs|         90|
+-----------+--------+-----------+

Ultimately, I want to aggregate the resulting dataframe to get the totals per category:

+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    milk|              105|
|   bread|              315|
|    eggs|              190|
+--------+-----------------+

UPDATE: There is a suggestion that this question is a duplicate and can be answered here. This is not the case, as the solution casts rows to columns and I need to do the reverse, melt columns into rows.

Gary C asked Mar 27 '19 13:03



1 Answer

We can use the explode() function to solve this. In pandas, the same thing can be done with melt().
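
For comparison, here is a minimal pandas sketch of the melt mentioned above, using the sample data from the question. It runs on a single machine, so it only illustrates the target shape; it is not a substitute for the Spark approach on tens of millions of rows. The names pdf, long_pdf and totals are made up for this illustration.

```python
import pandas as pd

# Sample data copied from the question
pdf = pd.DataFrame({
    "store_id": [100, 200, 300],
    "qty_on_hand_milk": [30, 55, 20],
    "qty_on_hand_bread": [105, 85, 125],
    "qty_on_hand_eggs": [35, 65, 90],
})

# Wide to long: one row per (store_id, original column)
long_pdf = pdf.melt(id_vars="store_id", var_name="CATEGORY", value_name="qty_on_hand")

# Strip the shared prefix so only the category name remains
long_pdf["CATEGORY"] = long_pdf["CATEGORY"].str.replace("qty_on_hand_", "", regex=False)

# Total per category, as a plain dict for easy inspection
totals = long_pdf.groupby("CATEGORY")["qty_on_hand"].sum().to_dict()
print(totals)
```

The Spark version below follows the same three steps: melt, strip the prefix, aggregate.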

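# Loading the requisite packages
# Note: importing sum from pyspark.sql.functions shadows Python's built-in sum
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit
# Creating the DataFrame
# sqlContext is assumed to exist in the session; with a SparkSession, spark.createDataFrame works the same way
df = sqlContext.createDataFrame([(100,30,105,35),(200,55,85,65),(300,20,125,90)],('store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'))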
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+

The function below explodes this DataFrame, turning every column not listed in by into (CATEGORY, qty_on_hand) rows:

def to_explode(df, by):

    # Filter dtypes and split into column names and type description
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # array() needs elements of a single type, so the melted columns must all match
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"

    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
      struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
    ])).alias("kvs")

    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])
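
To build intuition for what the array-of-structs plus explode() combination does, here is a plain-Python model of the same reshaping on a list of dicts. It is only a sketch for illustration: melt_rows is a hypothetical helper, not part of PySpark, and it runs locally rather than on the cluster.

```python
def melt_rows(rows, by):
    """Plain-Python model of to_explode; rows is a list of dicts."""
    out = []
    for row in rows:
        # Every column not listed in `by` becomes its own output row
        for name, value in row.items():
            if name in by:
                continue
            melted = {key: row[key] for key in by}  # carry the id columns over
            melted["CATEGORY"] = name               # original column name
            melted["qty_on_hand"] = value           # original column value
            out.append(melted)
    return out


rows = [
    {"store_id": 100, "qty_on_hand_milk": 30, "qty_on_hand_bread": 105, "qty_on_hand_eggs": 35},
]
result = melt_rows(rows, ["store_id"])
for r in result:
    print(r)
```

Each input row yields one output row per melted column, which is why three stores become nine rows in the Spark output.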

Apply the function to the DataFrame, then drop store_id, since it is not needed for the per-category totals:

df = to_explode(df, ['store_id'])\
     .drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+

Now we need to remove the prefix qty_on_hand_ from the CATEGORY column. This can be done with the expr() function. Note that Spark SQL's substring uses 1-based positions, unlike Python's 0-based indexing, so position 13 is the first character after the 12-character prefix:

df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+
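
The offset 13 used above can be checked with a few lines of plain Python. This is only a sanity check of the arithmetic; prefix, column, spark_pos and category are illustrative names.

```python
prefix = "qty_on_hand_"
column = "qty_on_hand_bread"

# Spark SQL substring positions are 1-based, so the first character
# after the prefix sits at position len(prefix) + 1
spark_pos = len(prefix) + 1

# Python slices are 0-based, so the equivalent slice starts at len(prefix)
category = column[len(prefix):]

print(spark_pos, category)
```

Deriving the position from the prefix length avoids a hard-coded 13 if the prefix ever changes.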

Finally, aggregate the qty_on_hand column grouped by CATEGORY using the agg() function:

df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+
cph_sto answered Sep 24 '22 19:09
