
How to pivot Spark DataFrame?

I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.

I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality or has suggestions for how to write something in Scala, it would be greatly appreciated.

asked May 14 '15 by J Calbreath


People also ask

How do you pivot and Unpivot in PySpark?

PySpark's pivot() function is used to rotate/transpose the data from one column into multiple DataFrame columns, and unpivot() transforms it back. pivot() is an aggregation where the values of one of the grouping columns are transposed into individual columns with distinct data.

How do I convert columns to rows in spark?

Spark's pivot() function is used to pivot/rotate the data from one DataFrame/Dataset column into multiple columns (transform rows to columns), and unpivot is used to transform it back (transform columns to rows).
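
A minimal Scala sketch of that round trip, assuming a SparkSession named spark; the data and column names are made up for illustration:

import org.apache.spark.sql.functions.first
import spark.implicits._

val longDF = Seq((1, "US", 50), (1, "UK", 100), (2, "US", 75), (2, "UK", 150))
  .toDF("id", "tag", "value")

// pivot: the distinct values of `tag` become columns
val wideDF = longDF.groupBy("id").pivot("tag").agg(first("value"))

// unpivot back to long form with stack(); a dedicated unpivot/melt only exists in newer Spark versions
val unpivoted = wideDF.selectExpr("id", "stack(2, 'UK', UK, 'US', US) as (tag, value)")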


2 Answers

As mentioned by David Anderson, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:

df
  .groupBy(grouping_columns)
  .pivot(pivot_column, [values])
  .agg(aggregate_expressions)

Usage examples using the nycflights13 data and csv format:

Python:

from pyspark.sql.functions import avg

flights = (sqlContext
    .read
    .format("csv")
    .options(inferSchema="true", header="true")
    .load("flights.csv")
    .na.drop())

flights.registerTempTable("flights")
sqlContext.cacheTable("flights")

gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")

flights.count()
## 336776

%timeit -n10 flights.groupBy(*gexprs).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop

Scala:

val flights = sqlContext
  .read
  .format("csv")
  .options(Map("inferSchema" -> "true", "header" -> "true"))
  .load("flights.csv")

flights
  .groupBy($"origin", $"dest", $"carrier")
  .pivot("hour")
  .agg(avg($"arr_delay"))

Java:

import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;

Dataset<Row> df = spark.read().format("csv")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("flights.csv");

df.groupBy(col("origin"), col("dest"), col("carrier"))
        .pivot("hour")
        .agg(avg(col("arr_delay")));

R / SparkR:

library(magrittr)

flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

flights %>%
  groupBy("origin", "dest", "carrier") %>%
  pivot("hour") %>%
  agg(avg(column("arr_delay")))

R / sparklyr

library(dplyr)

flights <- spark_read_csv(sc, "flights", "flights.csv")

avg.arr.delay <- function(gdf) {
  expr <- invoke_static(
    sc,
    "org.apache.spark.sql.functions",
    "avg",
    "arr_delay"
  )
  gdf %>% invoke("agg", expr, list())
}

flights %>%
  sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)

SQL:

Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.

CREATE TEMPORARY VIEW flights
USING csv
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true');

SELECT * FROM (
  SELECT origin, dest, carrier, arr_delay, hour FROM flights
) PIVOT (
  avg(arr_delay)
  FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
);

Example data:

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour" 2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00 2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00 2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00 2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00 2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00 2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00 2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00 2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00 2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00 

Performance considerations:

Generally speaking, pivoting is an expensive operation.

  • if you can, try to provide the values list, as this avoids an extra hit to compute the unique values:

    vs = list(range(25))
    %timeit -n10 flights.groupBy(*gexprs).pivot("hour", vs).agg(aggexpr).count()
    ## 10 loops, best of 3: 392 ms per loop
  • in some cases it proved to be beneficial (though likely no longer worth the effort in 2.0 or later) to repartition and/or pre-aggregate the data (see the sketch after this list)

  • for reshaping only, you can use first: Pivot String column on Pyspark Dataframe
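
A rough Scala sketch of the pre-aggregation point above, reusing the flights DataFrame from the Scala example earlier; whether it actually pays off depends on your data and Spark version:

import org.apache.spark.sql.functions.{avg, first}

// Collapse to one row per (origin, dest, carrier, hour) before pivoting,
// so the pivot itself has far fewer rows to process.
val preAggregated = flights
  .groupBy($"origin", $"dest", $"carrier", $"hour")
  .agg(avg($"arr_delay").as("arr_delay"))

// Passing the pivot values explicitly (0 to 23) also skips the extra job
// that computes the distinct values of `hour`.
preAggregated
  .groupBy($"origin", $"dest", $"carrier")
  .pivot("hour", 0 to 23)
  .agg(first($"arr_delay"))  // first() is enough: each group/hour cell now holds exactly one pre-averaged row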

Related questions:

  • How to melt Spark DataFrame?
  • Unpivot in spark-sql/pyspark
  • Transpose column to row with Spark
answered Oct 06 '22 by 15 revs, 6 users 60%


I overcame this by writing a for loop to dynamically create a SQL query. Say I have:

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175

and I want:

id  US  UK   Can
1   50  100  125
2   75  150  175

I can create a list with the values I want to pivot on and then create a string containing the SQL query I need.

val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1

var query = "select *, "
for (i <- 0 to numCountries-1) {
  query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
}
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"

myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)

I can then create a similar query to do the aggregation (a sketch of what that could look like is below). Not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
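
For illustration, continuing from the snippet above, the second query could be built the same way; the temp table name myPivotedTable is just a hypothetical placeholder:

// Sum each generated country column per id.
// `countries`, `numCountries`, `sqlContext` and `myDF1` come from the snippet above.
var aggQuery = "select id, "
for (i <- 0 to numCountries - 1) {
  aggQuery += "sum(" + countries(i) + ") as " + countries(i) + ", "
}
aggQuery += "sum(" + countries.last + ") as " + countries.last + " from myPivotedTable group by id"

myDF1.registerTempTable("myPivotedTable")
val myDF2 = sqlContext.sql(aggQuery)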

answered Oct 06 '22 by J Calbreath