How to pivot on multiple columns in Spark SQL?

I need to pivot more than one column in a PySpark DataFrame. Sample DataFrame:

>>> d = [(100,1,23,10),(100,2,45,11),(100,3,67,12),(100,4,78,13),
...      (101,1,23,10),(101,2,45,13),(101,3,67,14),(101,4,78,15),
...      (102,1,23,10),(102,2,45,11),(102,3,67,16),(102,4,78,18)]
>>> mydf = spark.createDataFrame(d,['id','day','price','units'])
>>> mydf.show()
+---+---+-----+-----+
| id|day|price|units|
+---+---+-----+-----+
|100|  1|   23|   10|
|100|  2|   45|   11|
|100|  3|   67|   12|
|100|  4|   78|   13|
|101|  1|   23|   10|
|101|  2|   45|   13|
|101|  3|   67|   14|
|101|  4|   78|   15|
|102|  1|   23|   10|
|102|  2|   45|   11|
|102|  3|   67|   16|
|102|  4|   78|   18|
+---+---+-----+-----+

Now, to turn the price column into one row per id, with one column per day, I can use the pivot method:

>>> from pyspark.sql import functions as F
>>> pvtdf = mydf.withColumn('combcol',F.concat(F.lit('price_'),mydf['day'])).groupby('id').pivot('combcol').agg(F.first('price'))
>>> pvtdf.show()
+---+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|
+---+-------+-------+-------+-------+
|100|     23|     45|     67|     78|
|101|     23|     45|     67|     78|
|102|     23|     45|     67|     78|
+---+-------+-------+-------+-------+
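The query above can be modeled in plain Python (no Spark needed) to show what it computes: rows are grouped by id, each day becomes a column named price_<day>, and first keeps the value seen for that cell. Because every (id, day) pair occurs exactly once in the sample data, first simply returns that single price. pivot_first is a hypothetical helper name used only for illustration, not a Spark API.

```python
# Plain-Python model of: groupby('id').pivot('combcol').agg(F.first('price'))
# Rows are (id, day, price, units) tuples, copied from the question's sample data.
ROWS = [
    (100, 1, 23, 10), (100, 2, 45, 11), (100, 3, 67, 12), (100, 4, 78, 13),
    (101, 1, 23, 10), (101, 2, 45, 13), (101, 3, 67, 14), (101, 4, 78, 15),
    (102, 1, 23, 10), (102, 2, 45, 11), (102, 3, 67, 16), (102, 4, 78, 18),
]
FIELDS = ("id", "day", "price", "units")


def pivot_first(rows, value):
    """Hypothetical helper: {id: {f'{value}_{day}': first value seen}}."""
    idx = FIELDS.index(value)
    out = {}
    for row in rows:
        key = f"{value}_{row[1]}"          # plays the role of 'combcol'
        cells = out.setdefault(row[0], {})  # one output row per id
        cells.setdefault(key, row[idx])     # keep only the first value, like F.first
    return out


pvt = pivot_first(ROWS, "price")
print(pvt[100])  # {'price_1': 23, 'price_2': 45, 'price_3': 67, 'price_4': 78}
```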

If I also need the units column transposed the same way as price, I have to create another DataFrame for units as above and then join the two on id. Since I have more columns like this, I wrote a function to do it:

>>> def pivot_udf(df,*cols):
...     mydf = df.select('id').drop_duplicates()
...     for c in cols:
...         mydf = mydf.join(df.withColumn('combcol',F.concat(F.lit('{}_'.format(c)),df['day'])).groupby('id').pivot('combcol').agg(F.first(c)),'id')
...     return mydf
...
>>> pivot_udf(mydf,'price','units').show()
+---+-------+-------+-------+-------+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|units_1|units_2|units_3|units_4|
+---+-------+-------+-------+-------+-------+-------+-------+-------+
|100|     23|     45|     67|     78|     10|     11|     12|     13|
|101|     23|     45|     67|     78|     10|     13|     14|     15|
|102|     23|     45|     67|     78|     10|     11|     16|     18|
+---+-------+-------+-------+-------+-------+-------+-------+-------+
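To make the cost of this function concrete, here is a plain-Python model (no Spark involved) of what pivot_udf does: one pivot per value column, each inner-joined on id. CountingSource and pivot_first are hypothetical helpers; the counter only tracks full reads of the input, so it understates the work Spark does (pivot without an explicit list of values must also compute the distinct pivot values first), and Spark's actual plan may differ.

```python
# Plain-Python model of the question's pivot_udf: one pivot per value column,
# then an inner join on id. A counter records how often the input is read.
ROWS = [
    (100, 1, 23, 10), (100, 2, 45, 11), (100, 3, 67, 12), (100, 4, 78, 13),
    (101, 1, 23, 10), (101, 2, 45, 13), (101, 3, 67, 14), (101, 4, 78, 15),
    (102, 1, 23, 10), (102, 2, 45, 11), (102, 3, 67, 16), (102, 4, 78, 18),
]
FIELDS = ("id", "day", "price", "units")


class CountingSource:
    """Hypothetical stand-in for an uncached DataFrame: counts full reads."""

    def __init__(self, rows):
        self.rows = rows
        self.reads = 0

    def __iter__(self):
        self.reads += 1
        return iter(self.rows)


def pivot_first(source, value):
    idx = FIELDS.index(value)
    out = {}
    for row in source:
        out.setdefault(row[0], {}).setdefault(f"{value}_{row[1]}", row[idx])
    return out


def pivot_udf(source, *cols):
    result = {row[0]: {} for row in source}  # like select('id').drop_duplicates()
    for c in cols:
        pivoted = pivot_first(source, c)      # one more read per value column
        # inner join on id: keep ids present on both sides, merge their cells
        result = {k: {**result[k], **pivoted[k]} for k in result if k in pivoted}
    return result


src = CountingSource(ROWS)
wide = pivot_udf(src, "price", "units")
print(src.reads)  # 3: one read for the distinct ids, one per pivoted column
```

The read count grows with the number of columns passed in, which is what the single-pivot approach below avoids.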

Is this good practice, and is there a better way of doing it? Thanks in advance!

2 Answers

Here's a non-UDF way that needs only a single pivot and no joins (hence just a single scan of the day column to identify all its distinct values).

dff = mydf.groupBy('id').pivot('day').agg(F.first('price').alias('price'),F.first('units').alias('unit')) 

Here's the result (apologies for the non-matching ordering and naming):

+---+-------+------+-------+------+-------+------+-------+------+
| id|1_price|1_unit|2_price|2_unit|3_price|3_unit|4_price|4_unit|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+

After pivoting on day, we simply aggregate both the price and units columns in the same agg call; each resulting column is named <day>_<alias>.
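A plain-Python model of this single pivot (again, no Spark involved): every row is visited once and contributes one cell per aggregation, named <day>_<alias>, which matches the 1_price, 1_unit, ... columns shown above. pivot_multi is a hypothetical helper, not a Spark API.

```python
# Plain-Python model of:
#   groupBy('id').pivot('day').agg(F.first('price').alias('price'),
#                                  F.first('units').alias('unit'))
ROWS = [
    (100, 1, 23, 10), (100, 2, 45, 11), (100, 3, 67, 12), (100, 4, 78, 13),
    (101, 1, 23, 10), (101, 2, 45, 13), (101, 3, 67, 14), (101, 4, 78, 15),
    (102, 1, 23, 10), (102, 2, 45, 11), (102, 3, 67, 16), (102, 4, 78, 18),
]
FIELDS = ("id", "day", "price", "units")


def pivot_multi(rows, aggs):
    """Hypothetical helper: one pass, one cell per (day, alias), first value wins."""
    out = {}
    for row in rows:
        cells = out.setdefault(row[0], {})
        for col, alias in aggs:
            cells.setdefault(f"{row[1]}_{alias}", row[FIELDS.index(col)])
    return out


dff = pivot_multi(ROWS, [("price", "price"), ("units", "unit")])
print(dff[100])
# {'1_price': 23, '1_unit': 10, '2_price': 45, '2_unit': 11, '3_price': 67, '3_unit': 12, '4_price': 78, '4_unit': 13}
```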

If you need the column names from the question, reverse the parts of each name:

dff.select([F.col(c).name('_'.join(x for x in c.split('_')[::-1])) for c in dff.columns]).show()
+---+-------+------+-------+------+-------+------+-------+------+
| id|price_1|unit_1|price_2|unit_2|price_3|unit_3|price_4|unit_4|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+
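Reversing every '_'-separated part works for these names, but it would scramble an alias that itself contains an underscore: 1_unit_price would become price_unit_1. A slightly safer sketch, assuming the pivot values (the days) never contain '_', splits only at the first underscore. day_last is a hypothetical helper name.

```python
def day_last(name):
    """Turn '<day>_<alias>' into '<alias>_<day>'; names without '_' (like 'id') pass through.

    Assumes the pivot value (the part before the first '_') contains no underscore.
    """
    if "_" not in name:
        return name
    day, alias = name.split("_", 1)  # split only at the first '_'
    return f"{alias}_{day}"


print([day_last(c) for c in ["id", "1_price", "1_unit", "2_price"]])
# ['id', 'price_1', 'unit_1', 'price_2']
print(day_last("1_unit_price"))  # unit_price_1
```

In PySpark the same helper could be applied with dff.toDF(*[day_last(c) for c in dff.columns]).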

The solution in the question is the best I could come up with. The only improvement would be to cache the input dataset so it is not scanned more than once, i.e.

mydf.cache()
pivot_udf(mydf,'price','units').show()
Jacek Laskowski