I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of one column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.
I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality, or has suggestions for how to write something in Scala, it would be greatly appreciated.
The PySpark pivot() function rotates (transposes) the values of one column into multiple DataFrame columns, and unpivot() reverses the operation. pivot() is an aggregation in which the distinct values of one of the grouping columns are transposed into individual columns.
The same applies in Spark generally: pivot() turns the rows of one DataFrame/Dataset column into multiple columns, and unpivot turns columns back into rows.
As mentioned by David Anderson, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:

    df
      .groupBy(grouping_columns)
      .pivot(pivot_column, [values])
      .agg(aggregate_expressions)
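To make the three steps concrete, here is a minimal plain-Python sketch (no Spark involved) of what a groupBy / pivot / agg chain with an average computes. The helper `pivot_mean` and the sample rows are hypothetical and exist only for illustration; Spark's real implementation is distributed and works differently internally.

```python
from collections import defaultdict
from statistics import mean


def pivot_mean(rows, group_cols, pivot_col, value_col, values=None):
    """Illustrative only: mimic groupBy(...).pivot(pivot_col, values).agg(avg(value_col))."""
    if values is None:
        # No values supplied: an extra pass over the data finds the distinct pivot values.
        values = sorted({row[pivot_col] for row in rows})
    wanted = set(values)

    # One bucket of raw values per (group key, pivot value) cell.
    cells = defaultdict(lambda: defaultdict(list))
    for row in rows:
        key = tuple(row[c] for c in group_cols)
        bucket = cells[key]  # every group gets an output row, even if all its cells stay empty
        if row[pivot_col] in wanted:
            bucket[row[pivot_col]].append(row[value_col])

    result = []
    for key, by_value in cells.items():
        out = dict(zip(group_cols, key))
        for v in values:
            # Empty cells become None, playing the role of a null in the wide result.
            out[str(v)] = mean(by_value[v]) if by_value[v] else None
        result.append(out)
    return result


# Hypothetical rows in the shape of the flights data (a subset of its columns).
sample = [
    {"origin": "EWR", "carrier": "UA", "hour": 5, "arr_delay": 11},
    {"origin": "EWR", "carrier": "UA", "hour": 5, "arr_delay": 12},
    {"origin": "EWR", "carrier": "B6", "hour": 6, "arr_delay": 19},
    {"origin": "LGA", "carrier": "UA", "hour": 5, "arr_delay": 20},
]

wide = pivot_mean(sample, ["origin", "carrier"], "hour", "arr_delay")
for row in wide:
    print(row)
```

Each distinct value of the pivot column becomes its own output column, and each grouping key yields exactly one row.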
Usage examples using nycflights13 and csv format:
Python:

    from pyspark.sql.functions import avg

    flights = (sqlContext
        .read
        .format("csv")
        .options(inferSchema="true", header="true")
        .load("flights.csv")
        .na.drop())

    flights.registerTempTable("flights")
    sqlContext.cacheTable("flights")

    gexprs = ("origin", "dest", "carrier")
    aggexpr = avg("arr_delay")

    flights.count()
    ## 336776

    %timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
    ## 10 loops, best of 3: 1.03 s per loop
Scala:

    import org.apache.spark.sql.functions.avg

    val flights = sqlContext
      .read
      .format("csv")
      .options(Map("inferSchema" -> "true", "header" -> "true"))
      .load("flights.csv")

    flights
      .groupBy($"origin", $"dest", $"carrier")
      .pivot("hour")
      .agg(avg($"arr_delay"))
Java:

    import static org.apache.spark.sql.functions.*;
    import org.apache.spark.sql.*;

    Dataset<Row> df = spark.read().format("csv")
            .option("inferSchema", "true")
            .option("header", "true")
            .load("flights.csv");

    df.groupBy(col("origin"), col("dest"), col("carrier"))
            .pivot("hour")
            .agg(avg(col("arr_delay")));
R / SparkR:

    library(magrittr)

    flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

    flights %>%
      groupBy("origin", "dest", "carrier") %>%
      pivot("hour") %>%
      agg(avg(column("arr_delay")))
R / sparklyr:

    library(dplyr)

    flights <- spark_read_csv(sc, "flights", "flights.csv")

    avg.arr.delay <- function(gdf) {
      expr <- invoke_static(
        sc,
        "org.apache.spark.sql.functions",
        "avg",
        "arr_delay"
      )
      gdf %>% invoke("agg", expr, list())
    }

    flights %>%
      sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)
SQL:

Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.

    CREATE TEMPORARY VIEW flights
    USING csv
    OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;

    SELECT * FROM (
      SELECT origin, dest, carrier, arr_delay, hour FROM flights
    ) PIVOT (
      avg(arr_delay)
      FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
    );
Example data:

    "year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
    2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
    2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
    2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
    2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
    2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
    2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
    2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
    2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
    2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00
Performance considerations:

Generally speaking, pivoting is an expensive operation.

- If you can, try to provide a values list, as this avoids an extra hit to compute the uniques:

      vs = list(range(25))
      %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count()
      ## 10 loops, best of 3: 392 ms per loop

- In some cases it proved to be beneficial (likely no longer worth the effort in 2.0 or later) to repartition and / or pre-aggregate the data.
- For reshaping only, you can use first: Pivot String column on Pyspark Dataframe
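The idea behind using first for reshaping can be sketched in plain Python (no Spark involved): when each key and pivot value pair occurs at most once, the aggregate only has to pick up the single value in each cell. The helper `reshape_first` and its rows are hypothetical. Passing the values explicitly also means the data does not have to be scanned first to discover them.

```python
def reshape_first(rows, key_col, pivot_col, value_col, values):
    """Illustrative only: long-to-wide reshape keeping the first value seen per cell."""
    cells = {}
    for row in rows:
        by_value = cells.setdefault(row[key_col], {})
        if row[pivot_col] in values:
            # setdefault keeps the first value seen for a cell and ignores later ones.
            by_value.setdefault(row[pivot_col], row[value_col])
    # Only the listed values become columns; cells with no data are None.
    return [
        {key_col: key, **{v: by_value.get(v) for v in values}}
        for key, by_value in cells.items()
    ]


# Hypothetical long-format rows.
long_rows = [
    {"id": 1, "tag": "US", "value": 50},
    {"id": 1, "tag": "UK", "value": 100},
    {"id": 2, "tag": "US", "value": 75},
    {"id": 2, "tag": "Can", "value": 175},
]

wide = reshape_first(long_rows, "id", "tag", "value", values=["US", "UK", "Can"])
for row in wide:
    print(row)
```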
I overcame this by writing a for loop to dynamically create a SQL query. Say I have:
    id  tag  value
    1   US   50
    1   UK   100
    1   Can  125
    2   US   75
    2   UK   150
    2   Can  175
and I want:
    id  US  UK   Can
    1   50  100  125
    2   75  150  175
I can create a list with the values I want to pivot and then create a string containing the SQL query I need.
    val countries = List("US", "UK", "Can")
    val numCountries = countries.length - 1

    var query = "select *, "
    for (i <- 0 to numCountries-1) {
      query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
    }
    query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"

    myDataFrame.registerTempTable("myTable")
    val myDF1 = sqlContext.sql(query)
I can create a similar query to then do the aggregation. It is not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
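As a rough sketch of what that aggregation query could look like, the Python snippet below builds a single statement that wraps each case expression in max and groups by id. It runs against an in-memory SQLite database purely so the example is self-contained; SQLite stands in for Spark SQL here, and the choice of max and the omission of "else 0" (so missing cells stay NULL) are assumptions of this sketch, not part of the original answer.

```python
import sqlite3

countries = ["US", "UK", "Can"]

# One aggregated CASE expression per pivot value.
# The values are interpolated into the SQL text, so they must come from a trusted list.
columns = ", ".join(
    f"max(case when tag = '{c}' then value end) as {c}" for c in countries
)
query = f"select id, {columns} from myTable group by id order by id"

# In-memory SQLite database used only as a stand-in for a registered Spark table.
conn = sqlite3.connect(":memory:")
conn.execute("create table myTable (id integer, tag text, value integer)")
conn.executemany(
    "insert into myTable values (?, ?, ?)",
    [(1, "US", 50), (1, "UK", 100), (1, "Can", 125),
     (2, "US", 75), (2, "UK", 150), (2, "Can", 175)],
)

rows = conn.execute(query).fetchall()
for row in rows:
    print(row)
conn.close()
```

With the sample table from the answer this yields one row per id, with one column per country.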