How to sort array of struct type in Spark DataFrame by particular column?

Given the following code:

import java.sql.Date
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object SortQuestion extends App{

  val spark = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
  import spark.implicits._
  case class ABC(a: Int, b: Int, c: Int)

  val first = Seq(
    ABC(1, 2, 3),
    ABC(1, 3, 4),
    ABC(2, 4, 5),
    ABC(2, 5, 6)
  ).toDF("a", "b", "c")

  val second = Seq(
    (1, 2, (Date.valueOf("2018-01-02"), 30)),
    (1, 3, (Date.valueOf("2018-01-01"), 20)),
    (2, 4, (Date.valueOf("2018-01-02"), 50)),
    (2, 5, (Date.valueOf("2018-01-01"), 60))
  ).toDF("a", "b", "c")

  first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b")).groupBy("a").agg(sort_array(collect_list("c2")))
    .show(false)

}

Spark produces the following result:

+---+----------------------------------+
|a  |sort_array(collect_list(c2), true)|
+---+----------------------------------+
|1  |[[2018-01-01,20], [2018-01-02,30]]|
|2  |[[2018-01-01,60], [2018-01-02,50]]|
+---+----------------------------------+

This implies that Spark is sorting the array by date (since it is the first field), but I want to instruct Spark to sort by a specific field of that nested struct.

I know I can reshape the array to (value, date), but that seems inconvenient; I want a general solution (imagine I have a big nested struct, 5 levels deep, and I want to sort that structure by a particular column).

Is there a way to do that? Am I missing something?

asked Apr 05 '18 by addmeaning

3 Answers

According to the Hive Wiki:

sort_array(Array<T>) : Sorts the input array in ascending order according to the natural ordering of the array elements and returns it (as of version 0.9.0).

This means that the array will be sorted lexicographically, which holds true even with complex data types: structs are compared field by field, starting with the first.
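Since struct comparison proceeds field by field, the reshape the question mentions can be sketched directly: rebuild the struct with the sort key first, and sort_array then orders the array by value (a sketch assuming the question's schema, where c2 is a struct of (date, value), and spark.implicits._ in scope):

```scala
import org.apache.spark.sql.functions._

// Rebuild the struct so the sort key (the Int) comes first;
// sort_array then compares structs starting with that field.
first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
  .groupBy("a")
  .agg(sort_array(collect_list(struct($"c2._2".as("value"), $"c2._1".as("date")))).as("sorted"))
  .show(false)
```

The downside, as the asker notes, is that the field order of the struct changes in the output.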

Alternatively, you can create a UDF to sort it based on the second element (and accept the performance degradation that comes with a UDF):

val sortUdf = udf { (xs: Seq[Row]) =>
  xs.sortBy(_.getAs[Int](1))                                // sort by the second struct field
    .map { case Row(x: java.sql.Date, y: Int) => (x, y) }   // rebuild as tuples for the result
}

first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
     .groupBy("a")
     .agg(sortUdf(collect_list("c2")))
     .show(false)

//+---+----------------------------------+
//|a  |UDF(collect_list(c2, 0, 0))       |
//+---+----------------------------------+
//|1  |[[2018-01-01,20], [2018-01-02,30]]|
//|2  |[[2018-01-02,50], [2018-01-01,60]]|
//+---+----------------------------------+
answered Oct 08 '22 by philantrovert

If you have complex objects, it is much better to use a statically typed Dataset.

case class Result(a: Int, b: Int, c: Int, c2: (java.sql.Date, Int))

val joined = first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
joined.as[Result]
  .groupByKey(_.a)
  .mapGroups((key, xs) => (key, xs.map(_.c2).toSeq.sortBy(_._2)))
  .show(false)

// +---+----------------------------------+            
// |_1 |_2                                |
// +---+----------------------------------+
// |1  |[[2018-01-01,20], [2018-01-02,30]]|
// |2  |[[2018-01-02,50], [2018-01-01,60]]|
// +---+----------------------------------+

In simple cases it is also possible to use a udf, but in general this leads to inefficient and fragile code, and it quickly gets out of control as the complexity of the objects grows.
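If you want readable column names instead of the generated _1/_2 in the result, one option (a sketch reusing the Result case class and joined value from above) is to rename the columns after mapGroups:

```scala
joined.as[Result]
  .groupByKey(_.a)
  .mapGroups((key, xs) => (key, xs.map(_.c2).toSeq.sortBy(_._2)))
  .toDF("a", "c2_sorted")  // restore meaningful column names
  .show(false)
```

Note that mapGroups deserializes every group onto the JVM heap, so this trades some of Spark's expression optimization for type safety.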

answered Oct 08 '22 by Alper t. Turker


For Spark 3+, you can pass a custom comparator function to array_sort:

The comparator will take two arguments representing two elements of the array. It returns -1, 0, or 1 as the first element is less than, equal to, or greater than the second element. If the comparator function returns other values (including null), the function will fail and raise an error.

val df = first
  .join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
  .groupBy("a")
  .agg(collect_list("c2").alias("list"))

val df2 = df.withColumn(
  "list",
  expr(
    "array_sort(list, (left, right) -> case when left._2 < right._2 then -1 when left._2 > right._2 then 1 else 0 end)"
  )
)

df2.show(false)
//+---+------------------------------------+
//|a  |list                                |
//+---+------------------------------------+
//|1  |[[2018-01-01, 20], [2018-01-02, 30]]|
//|2  |[[2018-01-02, 50], [2018-01-01, 60]]|
//+---+------------------------------------+

where _2 is the name of the struct field you want to use for sorting.
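As of Spark 3.4, my understanding is that the Scala DSL also exposes a comparator overload of array_sort (assumed signature: array_sort(e: Column, comparator: (Column, Column) => Column)), so the same sort can be written without expr:

```scala
import org.apache.spark.sql.functions._

// Assumption: Spark 3.4+ comparator overload of array_sort in the Scala API.
// The comparator must return -1, 0, or 1, as with the SQL lambda form.
val df3 = df.withColumn(
  "list",
  array_sort(
    col("list"),
    (left, right) =>
      when(left.getField("_2") < right.getField("_2"), -1)
        .when(left.getField("_2") > right.getField("_2"), 1)
        .otherwise(0)
  )
)
```

On earlier 3.x versions, the expr-based lambda shown above is the way to go.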

answered Oct 08 '22 by blackbishop