How to aggregate map columns after groupBy?

I need to union two dataframes and merge their map columns by key. The two dataframes have the same schema, for example:

root
 |-- id: string (nullable = true)
 |-- cMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

I want to group by "id" and merge the "cMap" maps into one, deduplicating keys. I tried this code:

val df = df_a.unionAll(df_b).groupBy("id").agg(collect_list("cMap") as "cMap").
rdd.map(x => {
    var map = Map[String,String]()
    x.getAs[Seq[Map[String,String]]]("cMap").foreach( y => 
        y.foreach( tuple =>
        {
            val key = tuple._1
            val value = tuple._2
            if(!map.contains(key))//deduplicate
                map += (key -> value)
        }))

    Row(x.getAs[String]("id"),map)
    })
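The per-id deduplication that loop performs (keep the first value seen for each key) does not need a mutable map. Below is a minimal plain-Scala sketch, with no Spark involved. It assumes the collected maps for one id arrive as a Seq[Map[String, String]] in the order they should take precedence. The names MergeFirstWins and mergeFirstWins are hypothetical.

```scala
object MergeFirstWins {
  // Hypothetical helper (not a Spark API): merges maps so that the first value
  // seen for a key wins, like the `if (!map.contains(key))` check in the loop.
  def mergeFirstWins(maps: Seq[Map[String, String]]): Map[String, String] =
    maps.foldLeft(Map.empty[String, String]) { (acc, m) =>
      // In `m ++ acc`, entries already in `acc` override those in `m`,
      // so keys seen earlier keep their original value.
      m ++ acc
    }
}
```

For example, mergeFirstWins(Seq(Map("k1" -> "a"), Map("k1" -> "b", "k2" -> "c"))) yields a map equal to Map("k1" -> "a", "k2" -> "c").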

But it seems collect_list cannot be used on a map column:

org.apache.spark.sql.AnalysisException: No handler for Hive udf class org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList because: Only primitive type arguments are accepted but map<string,string> was passed as parameter 1..;

Is there another solution to this problem?

Pingjiang Li asked Jan 21 '26 at 11:01



2 Answers

Since Spark 3.0, you can:

  • transform your map to an array of map entries with map_entries
  • collect those arrays by your id using collect_set
  • flatten the collected array of arrays using flatten
  • then rebuild the map from the flattened array using map_from_entries

See the following code snippet, where input is your input dataframe:

import org.apache.spark.sql.functions.{col, collect_set, flatten, map_entries, map_from_entries}

input
  .withColumn("cMap", map_entries(col("cMap")))
  .groupBy("id")
  .agg(map_from_entries(flatten(collect_set("cMap"))).as("cMap"))

Example

Given the following dataframe input:

+---+--------------------+
|id |cMap                |
+---+--------------------+
|1  |[k1 -> v1]          |
|1  |[k2 -> v2, k3 -> v3]|
|2  |[k4 -> v4]          |
|2  |[]                  |
|3  |[k6 -> v6, k7 -> v7]|
+---+--------------------+

The code snippet above returns the following dataframe:

+---+------------------------------+
|id |cMap                          |
+---+------------------------------+
|1  |[k1 -> v1, k2 -> v2, k3 -> v3]|
|3  |[k6 -> v6, k7 -> v7]          |
|2  |[k4 -> v4]                    |
+---+------------------------------+
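To make each step's effect on the data concrete, here is a plain-Scala collections analogue of the pipeline. It uses no Spark and is not Spark's implementation. It assumes the rows are given as (id, map) pairs, and the names MapEntriesPipeline, InputRow and aggregate are hypothetical.

```scala
object MapEntriesPipeline {
  type InputRow = (String, Map[String, String])

  // Plain-Scala analogue (no Spark) of:
  //   .withColumn("cMap", map_entries(col("cMap")))
  //   .groupBy("id")
  //   .agg(map_from_entries(flatten(collect_set("cMap"))))
  def aggregate(rows: Seq[InputRow]): Map[String, Map[String, String]] =
    rows
      .map { case (id, m) => (id, m.toSeq) }      // map_entries: map -> array of (key, value)
      .groupBy(_._1)                              // groupBy("id")
      .map { case (id, group) =>
        val collected = group.map(_._2).distinct  // collect_set: drops identical whole arrays only
        val flattened = collected.flatten         // flatten: one array of entries per id
        // map_from_entries analogue. toMap silently keeps the LAST value for a
        // repeated key; Spark 3.x behaves differently by default (see note below).
        id -> flattened.toMap
      }
}
```

Applied to the example input above, aggregate returns the same three maps as the table (ids 1, 2 and 3). One difference matters in practice. With Spark 3.x's default spark.sql.mapKeyDedupPolicy=EXCEPTION, map_from_entries throws an error when the flattened array contains the same key twice. Because collect_set deduplicates whole arrays, not individual entries, this can happen even when two rows of one id share an identical key-value pair. Wrapping the flattened array in array_distinct removes such exact duplicates. Genuinely conflicting values for one key still need spark.sql.mapKeyDedupPolicy=LAST_WIN or extra handling.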
Vincent Doba answered Jan 24 '26 at 03:01



You have to use the explode function on the map columns first to destructure the maps into key and value columns, union the resulting datasets, apply distinct to deduplicate, and only then groupBy with some custom Scala code to aggregate the maps.

Stop talking and let's do some coding then...

Given the datasets:

scala> a.show(false)
+---+-----------------------+
|id |cMap                   |
+---+-----------------------+
|one|Map(1 -> one, 2 -> two)|
+---+-----------------------+

scala> a.printSchema
root
 |-- id: string (nullable = true)
 |-- cMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

scala> b.show(false)
+---+-------------+
|id |cMap         |
+---+-------------+
|one|Map(1 -> one)|
+---+-------------+

scala> b.printSchema
root
 |-- id: string (nullable = true)
 |-- cMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

You should first use the explode function on the map columns.

explode(e: Column): Column
Creates a new row for each element in the given array or map column.

val a_keyValues = a.select('*, explode($"cMap"))
scala> a_keyValues.show(false)
+---+-----------------------+---+-----+
|id |cMap                   |key|value|
+---+-----------------------+---+-----+
|one|Map(1 -> one, 2 -> two)|1  |one  |
|one|Map(1 -> one, 2 -> two)|2  |two  |
+---+-----------------------+---+-----+

val b_keyValues = b.select('*, explode($"cMap"))
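In plain Scala terms, explode on a map column behaves like a flatMap that emits one (id, key, value) row per map entry. Here is a sketch without Spark. The names ExplodeAnalogue and explodeMap are hypothetical, and the input is assumed to be (id, map) pairs.

```scala
object ExplodeAnalogue {
  // Analogue of a.select($"id", explode($"cMap")): one output row per map entry.
  // A row whose map is empty yields no output rows, as with Spark's explode.
  def explodeMap(rows: Seq[(String, Map[String, String])]): Seq[(String, String, String)] =
    rows.flatMap { case (id, m) =>
      m.toSeq.map { case (k, v) => (id, k, v) }
    }
}
```

Because empty maps produce no rows, an id whose maps are all empty disappears from the final result of this approach. explode_outer (available since Spark 2.2) keeps such rows, with null key and value.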

The following gives you distinct key-value pairs, which is exactly the deduplication you asked for.

val distinctKeyValues = a_keyValues.
  union(b_keyValues).
  select("id", "key", "value").
  distinct // <-- deduplicate
scala> distinctKeyValues.show(false)
+---+---+-----+
|id |key|value|
+---+---+-----+
|one|1  |one  |
|one|2  |two  |
+---+---+-----+

Time to groupBy and create the final map column.

val result = distinctKeyValues.
  withColumn("map", map($"key", $"value")).
  groupBy("id").
  agg(collect_list("map")).
  as[(String, Seq[Map[String, String]])]. // <-- leave Rows for typed pairs
  map { case (id, list) => (id, list.reduce(_ ++ _)) }. // <-- collect all entries under one map
  toDF("id", "cMap") // <-- give the columns their names
scala> result.show(truncate = false)
+---+-----------------------+
|id |cMap                   |
+---+-----------------------+
|one|Map(1 -> one, 2 -> two)|
+---+-----------------------+
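The union, distinct, groupBy and reduce(_ ++ _) steps can be mirrored with plain Scala collections, which makes the conflict behaviour easy to check. This is a sketch without Spark. It assumes both inputs are already exploded into (id, key, value) triples, and the names GroupAndMerge and groupAndMerge are hypothetical.

```scala
object GroupAndMerge {
  // Analogue of a_keyValues.union(b_keyValues).select("id", "key", "value").distinct
  // followed by groupBy("id"), collect_list(map($"key", $"value")) and list.reduce(_ ++ _).
  def groupAndMerge(
      a: Seq[(String, String, String)],
      b: Seq[(String, String, String)]): Map[String, Map[String, String]] =
    (a ++ b)                  // Dataset.union keeps duplicates (UNION ALL semantics)
      .distinct               // removes identical (id, key, value) triples only
      .groupBy(_._1)
      .map { case (id, triples) =>
        val singleEntryMaps = triples.map { case (_, k, v) => Map(k -> v) } // map($"key", $"value")
        id -> singleEntryMaps.reduce(_ ++ _) // on a conflicting key, the later map's value wins
      }
}
```

Note that distinct only removes identical triples. If the same key carries different values in a and b, both survive, and reduce(_ ++ _) keeps whichever comes last. In Spark the element order produced by collect_list after a shuffle is not guaranteed, so which value wins is not deterministic.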

Please note that as of Spark 2.0.0 unionAll has been deprecated and union is the proper union operator:

(Since version 2.0.0) use union()

Jacek Laskowski answered Jan 24 '26 at 04:01



