
Spark pivot without aggregation

https://databricks.com/blog/2016/02/09/reshaping-data-with-pivot-in-apache-spark.html explains nicely how a pivot works in Spark.

In my Python code, I use a pandas pivot without an aggregation, then reset the index and join:

countryToMerge = pd.pivot_table(data=dfCountries, index=['A'], columns=['B'])
countryToMerge.index.name = 'ISO'
df.merge(countryToMerge['value'].reset_index(), on='ISO', how='inner')

How does this work in Spark?

I tried to group and join manually like:

val grouped = countryKPI.groupBy("A").pivot("B")
df.join(grouped, df.col("ISO") === grouped.col("ISO")).show

but that does not work. What is the Spark equivalent of reset_index, i.e. how would this be implemented in a Spark-native way?

Edit:

A minimal example of the Python code:

import pandas as pd
from datetime import datetime, timedelta
import numpy as np
dates = pd.DataFrame([(datetime(2016, 1, 1) + timedelta(i)).strftime('%Y-%m-%d') for i in range(10)], columns=["dates"])
isos = pd.DataFrame(["ABC", "POL", "ABC", "POL","ABC", "POL","ABC", "POL","ABC", "POL"], columns=['ISO'])
dates['ISO'] = isos.ISO
dates['ISO'] = dates['ISO'].astype("category")
countryKPI = pd.DataFrame({'country_id3':['ABC','POL','ABC','POL'],
                       'indicator_id':['a','a','b','b'],
                       'value':[7,8,9,7]})
countryToMerge = pd.pivot_table(data=countryKPI, index=['country_id3'], columns=['indicator_id'])
countryToMerge.index.name = 'ISO'
print(dates.merge(countryToMerge['value'].reset_index(), on='ISO', how='inner'))

        dates  ISO  a  b
0  2016-01-01  ABC  7  9
1  2016-01-03  ABC  7  9
2  2016-01-05  ABC  7  9
3  2016-01-07  ABC  7  9
4  2016-01-09  ABC  7  9
5  2016-01-02  POL  8  7
6  2016-01-04  POL  8  7
7  2016-01-06  POL  8  7
8  2016-01-08  POL  8  7
9  2016-01-10  POL  8  7

To follow along in Scala / Spark:

val dates = Seq(("2016-01-01", "ABC"),
    ("2016-01-02", "ABC"),
    ("2016-01-03", "POL"),
    ("2016-01-04", "ABC"),
    ("2016-01-05", "POL"),
    ("2016-01-06", "ABC"),
    ("2016-01-07", "POL"),
    ("2016-01-08", "ABC"),
    ("2016-01-09", "POL"),
    ("2016-01-10", "ABC")
  ).toDF("dates", "ISO")
    .withColumn("dates", 'dates.cast("Date"))

dates.show
dates.printSchema

val countryKPI = Seq(("ABC", "a", 7),
  ("ABC", "b", 8),
  ("POL", "a", 9),
  ("POL", "b", 7)
).toDF("country_id3", "indicator_id", "value")

countryKPI.show
countryKPI.printSchema

val grouped = countryKPI.groupBy("country_id3").pivot("indicator_id")
asked Nov 22 '16 by Georg Heiler


2 Answers

The following snippet seems to work, but I am not sure whether aggregating by avg is correct, even though the numbers in the output fit.

countryKPI.groupBy("country_id3").pivot("indicator_id").avg("value").show

I'm not sure whether avg is inefficient for a bigger amount of data compared to just reusing the values (I do not want to aggregate at all).
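Since every (country_id3, indicator_id) pair occurs exactly once here, avg simply passes the single value through; first does the same without numeric semantics. A minimal sketch (assuming the dates and countryKPI DataFrames from the question) that also joins the pivoted result back, mirroring the pandas reset_index + merge:

import org.apache.spark.sql.functions.first

// first("value") picks the single value per cell instead of averaging it
val pivoted = countryKPI
  .groupBy("country_id3")
  .pivot("indicator_id")
  .agg(first("value"))

// equivalent of the pandas reset_index + merge: join the wide table back on the key
dates
  .join(pivoted, dates("ISO") === pivoted("country_id3"), "inner")
  .drop("country_id3")
  .show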

answered Oct 15 '22 by Georg Heiler


There isn't a good way to pivot without aggregating in Spark; it basically assumes that you would just use a OneHotEncoder for that functionality, but that lacks the human readability of a straight pivot. The best ways that I have found to do it are:

import org.apache.spark.sql.functions.{col, first}

val pivot = countryKPI
  .groupBy("country_id3", "value")
  .pivot("indicator_id", Seq("a", "b"))
  .agg(first(col("indicator_id")))

pivot.show
+-----------+-----+----+----+
|country_id3|value|   a|   b|
+-----------+-----+----+----+
|        ABC|    8|null|   b|
|        POL|    9|   a|null|
|        POL|    7|null|   b|
|        ABC|    7|   a|null|
+-----------+-----+----+----+

However, if (country_id3, value) is not distinct within the dataset, then you will collapse rows and potentially take a somewhat meaningless first() value from your pivot column.
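For illustration, here is a sketch with a hypothetical extra row (invented for this example) that duplicates a (country_id3, value) pair:

// the added ("ABC", "b", 7) shares the group key (ABC, 7) with ("ABC", "a", 7)
val withDup = countryKPI
  .union(Seq(("ABC", "b", 7)).toDF("country_id3", "indicator_id", "value"))

withDup
  .groupBy("country_id3", "value")
  .pivot("indicator_id", Seq("a", "b"))
  .agg(first(col("indicator_id")))
  .show
// the two source rows are merged into a single row |ABC|7|a|b|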

An alternative is to add an id column to the dataset, group on that new id, pivot your desired column, then join back to the original dataset. Here's an example:

import org.apache.spark.sql.functions.monotonically_increasing_id

val countryWithId = countryKPI.withColumn("id", monotonically_increasing_id())
val pivotted = countryWithId
  .groupBy("id")
  .pivot("indicator_id")
  .agg(first(col("indicator_id")))

val pivot2 = countryWithId.join(pivotted, Seq("id")).drop("id") //.drop("indicator_id")

pivot2.show
+-----------+------------+-----+----+----+
|country_id3|indicator_id|value|   a|   b|
+-----------+------------+-----+----+----+
|        ABC|           a|    7|   a|null|
|        ABC|           b|    8|null|   b|
|        POL|           a|    9|   a|null|
|        POL|           b|    7|null|   b|
+-----------+------------+-----+----+----+

In this case, you still have the original pivot column, but you can .drop() that as well if you prefer.
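For example (a sketch; note the a/b columns stay sparse per row either way):

pivot2.drop("indicator_id").show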

answered Oct 15 '22 by Derek Kaknes