How to pivot a Spark DataFrame?

I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.

I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality, or has suggestions for how to write something in Scala, it would be greatly appreciated.

2 Answers

As mentioned by David Anderson, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:

    df
      .groupBy(grouping_columns)
      .pivot(pivot_column, [values])
      .agg(aggregate_expressions)
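To make the three parts of that call concrete, here is a minimal plain-Python sketch (no Spark involved) of what a groupBy / pivot / avg combination computes. The helper name pivot_avg and its parameters are hypothetical and exist only for this illustration; the rows are the origin, hour and arr_delay fields of the ten sample flights listed under "Example data".

```python
from collections import defaultdict

# (origin, hour, arr_delay) taken from the ten sample rows in this answer
rows = [
    {"origin": "EWR", "hour": 5, "arr_delay": 11},
    {"origin": "LGA", "hour": 5, "arr_delay": 20},
    {"origin": "JFK", "hour": 5, "arr_delay": 33},
    {"origin": "JFK", "hour": 5, "arr_delay": -18},
    {"origin": "LGA", "hour": 6, "arr_delay": -25},
    {"origin": "EWR", "hour": 5, "arr_delay": 12},
    {"origin": "EWR", "hour": 6, "arr_delay": 19},
    {"origin": "LGA", "hour": 6, "arr_delay": -14},
    {"origin": "JFK", "hour": 6, "arr_delay": -8},
    {"origin": "LGA", "hour": 6, "arr_delay": 8},
]


def pivot_avg(rows, group_cols, pivot_col, value_col, values=None):
    """Hypothetical helper: emulate groupBy(...).pivot(...).agg(avg(...))."""
    if values is None:
        # No explicit list: an extra pass is needed to find the distinct values.
        values = sorted({r[pivot_col] for r in rows})
    wanted = set(values)

    # group key -> pivot value -> list of values to average
    cells = defaultdict(lambda: defaultdict(list))
    for r in rows:
        key = tuple(r[c] for c in group_cols)
        bucket = cells[key]  # register the group even if no value matches
        if r[pivot_col] in wanted:
            bucket[r[pivot_col]].append(r[value_col])

    result = []
    for key, by_value in cells.items():
        out = dict(zip(group_cols, key))
        for v in values:
            xs = by_value.get(v)
            # One output column per pivot value; None where the cell is empty.
            out[str(v)] = sum(xs) / len(xs) if xs else None
        result.append(out)
    return result


result = pivot_avg(rows, ["origin"], "hour", "arr_delay")
print(result[0])  # {'origin': 'EWR', '5': 11.5, '6': 19.0}
```

Each distinct value of the pivot column becomes one output column, and each cell holds the aggregate of the rows that share both the group key and that pivot value.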

Usage examples using the nycflights13 data in CSV format:


Python:

    from pyspark.sql.functions import avg

    flights = (sqlContext
        .read
        .format("csv")
        .options(inferSchema="true", header="true")
        .load("flights.csv")
        .na.drop())

    flights.registerTempTable("flights")
    sqlContext.cacheTable("flights")

    gexprs = ("origin", "dest", "carrier")
    aggexpr = avg("arr_delay")

    flights.count()
    ## 336776

    %timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
    ## 10 loops, best of 3: 1.03 s per loop


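Scala:

    import org.apache.spark.sql.functions.avg
    import sqlContext.implicits._  // enables the $"column" syntax

    val flights = sqlContext
      .read
      .format("csv")
      .options(Map("inferSchema" -> "true", "header" -> "true"))
      .load("flights.csv")

    flights
      .groupBy($"origin", $"dest", $"carrier")
      .pivot("hour")
      .agg(avg($"arr_delay"))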


Java:

    import static org.apache.spark.sql.functions.*;
    import org.apache.spark.sql.*;

    Dataset<Row> df = spark.read().format("csv")
            .option("inferSchema", "true")
            .option("header", "true")
            .load("flights.csv");

    df.groupBy(col("origin"), col("dest"), col("carrier"))
            .pivot("hour")
            .agg(avg(col("arr_delay")));

R / SparkR:

    library(magrittr)

    flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

    flights %>%
      groupBy("origin", "dest", "carrier") %>%
      pivot("hour") %>%
      agg(avg(column("arr_delay")))

R / sparklyr:

    library(dplyr)

    flights <- spark_read_csv(sc, "flights", "flights.csv")

    avg.arr.delay <- function(gdf) {
      expr <- invoke_static(
        sc,
        "org.apache.spark.sql.functions",
        "avg",
        "arr_delay"
      )
      gdf %>% invoke("agg", expr, list())
    }

    flights %>%
      sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)


Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.

    CREATE TEMPORARY VIEW flights
    USING csv
    OPTIONS (header 'true', path 'flights.csv', inferSchema 'true');

    SELECT * FROM (
      SELECT origin, dest, carrier, arr_delay, hour FROM flights
    ) PIVOT (
      avg(arr_delay)
      FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
    );

Example data:

    "year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
    2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
    2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
    2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
    2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
    2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
    2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
    2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
    2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00

Performance considerations:

Generally speaking, pivoting is an expensive operation.

  • if you can, provide the list of values, as this avoids an extra pass over the data to compute the distinct values of the pivot column:

    vs = list(range(25))
    %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count()
    ## 10 loops, best of 3: 392 ms per loop

  • in some cases it proved beneficial to repartition and / or pre-aggregate the data (likely no longer worth the effort in 2.0 or later)

  • for reshaping only, you can use first as the aggregate function; see "Pivot String column on Pyspark Dataframe"

I overcame this by writing a for loop to dynamically create a SQL query. Say I have:

    id  tag  value
    1   US    50
    1   UK    100
    1   Can   125
    2   US    75
    2   UK    150
    2   Can   175

and I want:

    id  US  UK   Can
    1   50  100  125
    2   75  150  175

I can create a list with the values I want to pivot on and then build a string containing the SQL query I need.

    val countries = List("US", "UK", "Can")

    // One "case when" column per country; every column except the last
    // is followed by a comma.
    var query = "select *, "
    for (i <- 0 until countries.length - 1) {
      query += "case when tag = '" + countries(i) + "' then value else 0 end as " + countries(i) + ", "
    }
    query += "case when tag = '" + countries.last + "' then value else 0 end as " + countries.last + " from myTable"

    myDataFrame.registerTempTable("myTable")
    val myDF1 = sqlContext.sql(query)

I can create a similar query to then do the aggregation. It is not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
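As a rough illustration of that aggregation step, the sketch below builds the aggregated query as a string and runs it against the example table. It is written in Python and uses the standard-library sqlite3 module purely as a stand-in for Spark SQL so that it runs without a cluster. The helper name build_pivot_query and the choice of max as the aggregate are assumptions made for this example, not part of the approach above.

```python
import sqlite3


def build_pivot_query(table, group_col, pivot_col, value_col, values):
    """Hypothetical helper: one aggregated "case when" column per pivot value."""
    # Plain string interpolation, as in the approach above; only use this
    # with trusted table, column and value names.
    cases = ", ".join(
        f"max(case when {pivot_col} = '{v}' then {value_col} end) as {v}"
        for v in values
    )
    return (
        f"select {group_col}, {cases} from {table} "
        f"group by {group_col} order by {group_col}"
    )


# The example table from this answer, loaded into an in-memory SQLite database.
conn = sqlite3.connect(":memory:")
conn.execute("create table myTable (id integer, tag text, value integer)")
conn.executemany(
    "insert into myTable values (?, ?, ?)",
    [(1, "US", 50), (1, "UK", 100), (1, "Can", 125),
     (2, "US", 75), (2, "UK", 150), (2, "Can", 175)],
)

query = build_pivot_query("myTable", "id", "tag", "value", ["US", "UK", "Can"])
pivoted = conn.execute(query).fetchall()
print(pivoted)  # [(1, 50, 100, 125), (2, 75, 150, 175)]
```

The same generated string could be passed to sqlContext.sql in place of the non-aggregated query, with whichever aggregate fits the data.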

