Spark dataframe transform multiple rows to column

I am new to Spark, and I want to transform the source dataframe below (loaded from a JSON file):

+--+-----+-----+
|A |count|major|
+--+-----+-----+
| a|    1|   m1|
| a|    1|   m2|
| a|    2|   m3|
| a|    3|   m4|
| b|    4|   m1|
| b|    1|   m2|
| b|    2|   m3|
| c|    3|   m1|
| c|    4|   m3|
| c|    5|   m4|
| d|    6|   m1|
| d|    1|   m2|
| d|    2|   m3|
| d|    3|   m4|
| d|    4|   m5|
| e|    4|   m1|
| e|    5|   m2|
| e|    1|   m3|
| e|    1|   m4|
| e|    1|   m5|
+--+-----+-----+

into the result dataframe below:

+--+--+--+--+--+--+
|A |m1|m2|m3|m4|m5|
+--+--+--+--+--+--+
| a| 1| 1| 2| 3| 0|
| b| 4| 1| 2| 0| 0|
| c| 3| 0| 4| 5| 0|
| d| 6| 1| 2| 3| 4|
| e| 4| 5| 1| 1| 1|
+--+--+--+--+--+--+

Here are the transformation rules:

  1. The result dataframe consists of A plus n major columns, where the major column names are given by:

    sorted(src_df.map(lambda x: x[2]).distinct().collect())
    
  2. The result dataframe contains m rows, where the values of the A column are given by:

    sorted(src_df.map(lambda x: x[0]).distinct().collect())
    
  3. The value of each major column in the result dataframe is the count from the source dataframe for the corresponding A and major (e.g. the count in row 1 of the source dataframe goes into the cell where A is a and the column is m1).

  4. The combinations of A and major in the source dataframe contain no duplicates (think of the two columns as a composite primary key in SQL).

asked Nov 16 '15 by resec

2 Answers

Using zero323's example dataframe:

    df = sqlContext.createDataFrame([
        ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
        ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
        ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
        ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
        ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
        ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
        ("e", 1, "m4"), ("e", 1, "m5")],
        ("a", "cnt", "major"))

you could also use the built-in pivot (available since Spark 1.6):

    reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
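
Because each (a, major) pair is unique (rule 4 in the question), max simply picks the single cnt value for each cell. A quick sanity check (orderBy is added only because pivot does not guarantee row order):

    reshaped_df.orderBy('a').show()  # should reproduce the expected result table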

answered Oct 13 '22 by TrentWoodbury


Let's start with the example data:

df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")], 
    ("a", "cnt", "major"))

Please note that I've changed count to cnt: count is a reserved keyword in most SQL dialects, so it is not a good choice for a column name.

There are at least two ways to reshape this data:

  • aggregating over DataFrame

    from pyspark.sql.functions import col, when, max
    
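    # Note: importing max from pyspark.sql.functions shadows Python's built-in max.
    # Collect the sorted, distinct major values to use as result column names
    # (DataFrame.map is Spark 1.x API; on Spark 2.x+ use df.select("major").rdd.map(...)).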
    majors = sorted(df.select("major")
        .distinct()
        .map(lambda row: row[0])
        .collect())
    
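    # For each major m, emit cnt when the row's major equals m, otherwise NULL.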
    cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m) 
        for m in majors]
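    # max() then picks the single non-null value per group; this works because
    # the (a, major) pairs are unique.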
    maxs = [max(col(m)).alias(m) for m in majors]
    
    reshaped1 = (df
        .select(col("a"), *cols)
        .groupBy("a")
        .agg(*maxs)
        .na.fill(0))
    
    reshaped1.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## |  e|  4|  5|  1|  1|  1|
    ## +---+---+---+---+---+---+
    
  • groupBy over RDD

    from pyspark.sql import Row
    
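    # Pair each row as (a, (major, cnt)), then group all pairs by a.
    # (As above, on Spark 2.x+ this requires df.rdd.map(...) instead of df.map(...).)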
    grouped = (df
        .map(lambda row: (row.a, (row.major, row.cnt)))
        .groupByKey())
    
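    # Build one Row per key, filling any major missing for that key with 0.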
    def make_row(kv):
        k, vs = kv
        tmp = dict(list(vs) + [("a", k)])
        return Row(**{k: tmp.get(k, 0) for k in ["a"] + majors})
    
    reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))
    
    reshaped2.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  e|  4|  5|  1|  1|  1|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## +---+---+---+---+---+---+
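
    Note that groupByKey does not guarantee row order (hence a, e, c, b, d above); apply reshaped2.orderBy("a") if sorted output is needed.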
    
answered Oct 13 '22 by zero323