 

How to pivot on multiple columns in Spark SQL?

I need to pivot more than one column in a PySpark dataframe. Sample dataframe:

>>> d = [(100,1,23,10),(100,2,45,11),(100,3,67,12),(100,4,78,13),
...      (101,1,23,10),(101,2,45,13),(101,3,67,14),(101,4,78,15),
...      (102,1,23,10),(102,2,45,11),(102,3,67,16),(102,4,78,18)]
>>> mydf = spark.createDataFrame(d, ['id','day','price','units'])
>>> mydf.show()
+---+---+-----+-----+
| id|day|price|units|
+---+---+-----+-----+
|100|  1|   23|   10|
|100|  2|   45|   11|
|100|  3|   67|   12|
|100|  4|   78|   13|
|101|  1|   23|   10|
|101|  2|   45|   13|
|101|  3|   67|   14|
|101|  4|   78|   15|
|102|  1|   23|   10|
|102|  2|   45|   11|
|102|  3|   67|   16|
|102|  4|   78|   18|
+---+---+-----+-----+

Now, if I need to get the price column into a row for each id based on day, I can use the pivot method:

>>> from pyspark.sql import functions as F  # needed for F below
>>> pvtdf = mydf.withColumn('combcol', F.concat(F.lit('price_'), mydf['day'])) \
...             .groupby('id').pivot('combcol').agg(F.first('price'))
>>> pvtdf.show()
+---+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|
+---+-------+-------+-------+-------+
|100|     23|     45|     67|     78|
|101|     23|     45|     67|     78|
|102|     23|     45|     67|     78|
+---+-------+-------+-------+-------+

So when I need the units column transposed in the same way as price, I have to create one more dataframe as above for units and then join both using id. Since I may have more such columns, I tried writing a function to do it:

>>> def pivot_udf(df, *cols):
...     mydf = df.select('id').drop_duplicates()
...     for c in cols:
...         mydf = mydf.join(
...             df.withColumn('combcol', F.concat(F.lit('{}_'.format(c)), df['day']))
...               .groupby('id').pivot('combcol').agg(F.first(c)),
...             'id')
...     return mydf
...
>>> pivot_udf(mydf, 'price', 'units').show()
+---+-------+-------+-------+-------+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|units_1|units_2|units_3|units_4|
+---+-------+-------+-------+-------+-------+-------+-------+-------+
|100|     23|     45|     67|     78|     10|     11|     12|     13|
|101|     23|     45|     67|     78|     10|     13|     14|     15|
|102|     23|     45|     67|     78|     10|     11|     16|     18|
+---+-------+-------+-------+-------+-------+-------+-------+-------+

I need suggestions on whether it is good practice to do this, and whether there is a better way of doing it. Thanks in advance!

asked Jul 11 '17 by Suresh


2 Answers

Here's a non-UDF way involving a single pivot (hence, just a single column scan to identify all the unique days).

dff = mydf.groupBy('id').pivot('day').agg(
    F.first('price').alias('price'),
    F.first('units').alias('unit'))

Here's the result (apologies for the non-matching ordering and naming):

+---+-------+------+-------+------+-------+------+-------+------+
| id|1_price|1_unit|2_price|2_unit|3_price|3_unit|4_price|4_unit|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+

We just aggregate on both the price and the unit columns after pivoting on the day.
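As an aside, pivot() also accepts an explicit list of values; if the days are known in advance, passing them avoids even that single scan. A minimal sketch, assuming days 1 through 4 as in the sample data:

# sketch: pass the known day values explicitly so Spark can skip the
# scan for distinct pivot values ([1, 2, 3, 4] reflects the sample data)
dff = mydf.groupBy('id').pivot('day', [1, 2, 3, 4]).agg(
    F.first('price').alias('price'),
    F.first('units').alias('unit'))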

If the naming required is as in the question:

dff.select([F.col(c).name('_'.join(x for x in c.split('_')[::-1]))
            for c in dff.columns]).show()

+---+-------+------+-------+------+-------+------+-------+------+
| id|price_1|unit_1|price_2|unit_2|price_3|unit_3|price_4|unit_4|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+
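The renaming works by splitting each generated column name on '_' and reversing the pieces, so 1_price becomes price_1; the id column passes through unchanged because it contains no underscore.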
answered Sep 22 '22 by Jedi


The solution in the question is the best I could get. The only improvement would be to cache the input dataset to avoid the double scan, i.e.

mydf.cache()
pivot_udf(mydf, 'price', 'units').show()
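Once the pivoted result has been computed, the cached input can be released again; a minimal follow-up sketch:

# release the cached input once the pivoted output is no longer needed
mydf.unpersist()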
answered Sep 24 '22 by Jacek Laskowski