Querying Spark SQL DataFrame with complex types

How can I query an RDD with complex types such as maps/arrays? For example, when I was writing this test code:

case class Test(name: String, map: Map[String, String])
val map = Map("hello" -> "world", "hey" -> "there")
val map2 = Map("hello" -> "people", "hey" -> "you")
val rdd = sc.parallelize(Array(Test("first", map), Test("second", map2)))

I thought the syntax would be something like:

sqlContext.sql("SELECT * FROM rdd WHERE map.hello = world")

or

sqlContext.sql("SELECT * FROM rdd WHERE map[hello] = world")

but I get

Can't access nested field in type MapType(StringType,StringType,true)

and

org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Unresolved attributes


Answer

It depends on the type of the column. Let's start with some dummy data:

import org.apache.spark.sql.functions.{udf, lit}
import scala.util.Try
// toDF and the $"..." column syntax need the SQL implicits,
// which spark-shell imports automatically

case class SubRecord(x: Int)
case class ArrayElement(foo: String, bar: Int, vals: Array[Double])
case class Record(
  an_array: Array[Int], a_map: Map[String, String],
  a_struct: SubRecord, an_array_of_structs: Array[ArrayElement])

val df = sc.parallelize(Seq(
  Record(Array(1, 2, 3), Map("foo" -> "bar"), SubRecord(1),
         Array(
           ArrayElement("foo", 1, Array(1.0, 2.0, 2.0)),
           ArrayElement("bar", 2, Array(3.0, 4.0, 5.0)))),
  Record(Array(4, 5, 6), Map("foz" -> "baz"), SubRecord(2),
         Array(ArrayElement("foz", 3, Array(5.0, 6.0)),
               ArrayElement("baz", 4, Array(7.0, 8.0))))
)).toDF

df.registerTempTable("df")
df.printSchema

// root
// |-- an_array: array (nullable = true)
// |    |-- element: integer (containsNull = false)
// |-- a_map: map (nullable = true)
// |    |-- key: string
// |    |-- value: string (valueContainsNull = true)
// |-- a_struct: struct (nullable = true)
// |    |-- x: integer (nullable = false)
// |-- an_array_of_structs: array (nullable = true)
// |    |-- element: struct (containsNull = true)
// |    |    |-- foo: string (nullable = true)
// |    |    |-- bar: integer (nullable = false)
// |    |    |-- vals: array (nullable = true)
// |    |    |    |-- element: double (containsNull = false)
  • array (ArrayType) columns:

    • the Column.getItem method:

      df.select($"an_array".getItem(1)).show
      // +-----------+
      // |an_array[1]|
      // +-----------+
      // |          2|
      // |          5|
      // +-----------+
    • Hive brackets syntax:

      sqlContext.sql("SELECT an_array[1] FROM df").show
      // +---+
      // |_c0|
      // +---+
      // |  2|
      // |  5|
      // +---+
    • a UDF:

      val get_ith = udf((xs: Seq[Int], i: Int) => Try(xs(i)).toOption)

      df.select(get_ith($"an_array", lit(1))).show
      // +---------------+
      // |UDF(an_array,1)|
      // +---------------+
      // |              2|
      // |              5|
      // +---------------+
    • In addition to the methods listed above, Spark supports a growing list of built-in functions operating on complex types. Notable examples include higher-order functions like transform (SQL 2.4+, Scala 3.0+, PySpark / SparkR 3.1+):

      df.selectExpr("transform(an_array, x -> x + 1) an_array_inc").show
      // +------------+
      // |an_array_inc|
      // +------------+
      // |   [2, 3, 4]|
      // |   [5, 6, 7]|
      // +------------+

      import org.apache.spark.sql.functions.transform

      df.select(transform($"an_array", x => x + 1) as "an_array_inc").show
      // +------------+
      // |an_array_inc|
      // +------------+
      // |   [2, 3, 4]|
      // |   [5, 6, 7]|
      // +------------+
    • filter (SQL 2.4+, Scala 3.0+, PySpark / SparkR 3.1+):

      df.selectExpr("filter(an_array, x -> x % 2 == 0) an_array_even").show
      // +-------------+
      // |an_array_even|
      // +-------------+
      // |          [2]|
      // |       [4, 6]|
      // +-------------+

      import org.apache.spark.sql.functions.filter

      df.select(filter($"an_array", x => x % 2 === 0) as "an_array_even").show
      // +-------------+
      // |an_array_even|
      // +-------------+
      // |          [2]|
      // |       [4, 6]|
      // +-------------+
    • aggregate (SQL 2.4+, Scala 3.0+, PySpark / SparkR 3.1+):

      df.selectExpr("aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc) an_array_sum").show
      // +------------+
      // |an_array_sum|
      // +------------+
      // |           6|
      // |          15|
      // +------------+

      import org.apache.spark.sql.functions.aggregate

      df.select(aggregate($"an_array", lit(0), (x, y) => x + y) as "an_array_sum").show
      // +------------+
      // |an_array_sum|
      // +------------+
      // |           6|
      // |          15|
      // +------------+
    • array processing functions (array_*) like array_distinct (2.4+):

      import org.apache.spark.sql.functions.array_distinct

      df.select(array_distinct($"an_array_of_structs.vals"(0))).show
      // +-------------------------------------------+
      // |array_distinct(an_array_of_structs.vals[0])|
      // +-------------------------------------------+
      // |                                 [1.0, 2.0]|
      // |                                 [5.0, 6.0]|
      // +-------------------------------------------+
    • array_max (array_min, 2.4+):

      import org.apache.spark.sql.functions.array_max

      df.select(array_max($"an_array")).show
      // +-------------------+
      // |array_max(an_array)|
      // +-------------------+
      // |                  3|
      // |                  6|
      // +-------------------+
    • flatten (2.4+):

      import org.apache.spark.sql.functions.flatten

      df.select(flatten($"an_array_of_structs.vals")).show
      // +---------------------------------+
      // |flatten(an_array_of_structs.vals)|
      // +---------------------------------+
      // |             [1.0, 2.0, 2.0, 3...|
      // |             [5.0, 6.0, 7.0, 8.0]|
      // +---------------------------------+
    • arrays_zip (2.4+):

      import org.apache.spark.sql.functions.arrays_zip

      df.select(arrays_zip($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show(false)
      // +--------------------------------------------------------------------+
      // |arrays_zip(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
      // +--------------------------------------------------------------------+
      // |[[1.0, 3.0], [2.0, 4.0], [2.0, 5.0]]                                |
      // |[[5.0, 7.0], [6.0, 8.0]]                                            |
      // +--------------------------------------------------------------------+
    • array_union (2.4+):

      import org.apache.spark.sql.functions.array_union

      df.select(array_union($"an_array_of_structs.vals"(0), $"an_array_of_structs.vals"(1))).show
      // +---------------------------------------------------------------------+
      // |array_union(an_array_of_structs.vals[0], an_array_of_structs.vals[1])|
      // +---------------------------------------------------------------------+
      // |                                                 [1.0, 2.0, 3.0, 4...|
      // |                                                 [5.0, 6.0, 7.0, 8.0]|
      // +---------------------------------------------------------------------+
    • slice (2.4+):

      import org.apache.spark.sql.functions.slice

      df.select(slice($"an_array", 2, 2)).show
      // +---------------------+
      // |slice(an_array, 2, 2)|
      // +---------------------+
      // |               [2, 3]|
      // |               [5, 6]|
      // +---------------------+
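    To make the semantics of the array functions above concrete without a Spark session, here is a plain-Scala sketch that models each one with ordinary collection methods on the first record's values. It is an analogy, not how Spark evaluates these expressions, and `getIth` / `slice1` are hypothetical helper names. One difference to keep in mind: Scala's `zip` truncates to the shorter input, while Spark's `arrays_zip` pads the shorter array with nulls.

    ```scala
    // Plain-Scala model of the Spark array functions above (an analogy:
    // no Spark involved, just standard collection methods on local data).
    import scala.util.Try

    val anArray = Seq(1, 2, 3) // an_array of the first record

    // getItem(1) / an_array[1]: indices are 0-based. Like the get_ith UDF,
    // this hypothetical helper turns an out-of-range index into None (null).
    def getIth(xs: Seq[Int], i: Int): Option[Int] = Try(xs(i)).toOption
    val second = getIth(anArray, 1)
    val outOfRange = getIth(anArray, 10)

    // transform(an_array, x -> x + 1)
    val incremented = anArray.map(x => x + 1)

    // filter(an_array, x -> x % 2 == 0)
    val evens = anArray.filter(x => x % 2 == 0)

    // aggregate(an_array, 0, (acc, x) -> acc + x, acc -> acc):
    // fold with the merge function, then apply the finish function
    val finish: Int => Int = acc => acc
    val total = finish(anArray.foldLeft(0)((acc, x) => acc + x))

    // slice(an_array, 2, 2): start is 1-based, the second argument is a length
    def slice1(xs: Seq[Int], start: Int, length: Int): Seq[Int] =
      xs.slice(start - 1, start - 1 + length)
    val sliced = slice1(anArray, 2, 2)

    // array_max(an_array)
    val largest = anArray.max

    // an_array_of_structs.vals of the first record
    val vals0 = Seq(1.0, 2.0, 2.0)
    val vals1 = Seq(3.0, 4.0, 5.0)

    val distinctVals = vals0.distinct         // array_distinct
    val flattened = Seq(vals0, vals1).flatten // flatten
    val unioned = (vals0 ++ vals1).distinct   // array_union drops duplicates
    val zipped = vals0.zip(vals1)             // arrays_zip (equal lengths here)
    ```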
  • map (MapType) columns:

    • using the Column.getField method:

      df.select($"a_map".getField("foo")).show
      // +----------+
      // |a_map[foo]|
      // +----------+
      // |       bar|
      // |      null|
      // +----------+
    • using Hive brackets syntax:

      sqlContext.sql("SELECT a_map['foz'] FROM df").show
      // +----+
      // | _c0|
      // +----+
      // |null|
      // | baz|
      // +----+
    • using a full path with dot syntax:

      df.select($"a_map.foo").show
      // +----+
      // | foo|
      // +----+
      // | bar|
      // |null|
      // +----+
    • using a UDF:

      val get_field = udf((kvs: Map[String, String], k: String) => kvs.get(k))

      df.select(get_field($"a_map", lit("foo"))).show
      // +--------------+
      // |UDF(a_map,foo)|
      // +--------------+
      // |           bar|
      // |          null|
      // +--------------+
    • a growing number of map_* functions, like map_keys (2.3+):

      import org.apache.spark.sql.functions.map_keys

      df.select(map_keys($"a_map")).show
      // +---------------+
      // |map_keys(a_map)|
      // +---------------+
      // |          [foo]|
      // |          [foz]|
      // +---------------+
    • or map_values (2.3+):

      import org.apache.spark.sql.functions.map_values

      df.select(map_values($"a_map")).show
      // +-----------------+
      // |map_values(a_map)|
      // +-----------------+
      // |            [bar]|
      // |            [baz]|
      // +-----------------+

    Please check SPARK-23899 for a detailed list.
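    Back to the question: in `map[hello] = world` the unquoted `hello` and `world` are parsed as column references, which is a likely source of the "Unresolved attributes" error; the key and the compared value should both be string literals, something like `WHERE map['hello'] = 'world'`. The plain-Scala sketch below models that predicate and the map lookups above on local objects, with `None` standing in for the `null` Spark returns for a missing key. It is an analogy under those assumptions, not Spark code.

    ```scala
    // Plain-Scala model of map-column access (an analogy, not Spark code).
    // Option's None stands in for the null Spark returns for a missing key.

    // The question's data as local objects
    case class Test(name: String, map: Map[String, String])
    val rows = Seq(
      Test("first", Map("hello" -> "world", "hey" -> "there")),
      Test("second", Map("hello" -> "people", "hey" -> "you"))
    )

    // a_map['foo'], getField("foo") and the get_field UDF: a key lookup
    val aMap = Map("foo" -> "bar") // a_map of the first record
    val present = aMap.get("foo")
    val absent = aMap.get("foz")

    // map_keys(a_map) and map_values(a_map)
    val keys = aMap.keys.toList
    val values = aMap.values.toList

    // Roughly WHERE map['hello'] = 'world': a missing key gives null,
    // and a null comparison never selects the row
    val matching = rows.filter(_.map.get("hello").contains("world")).map(_.name)
    ```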

  • struct (StructType) columns, using the full path with dot syntax:

    • with the DataFrame API:

      df.select($"a_struct.x").show
      // +---+
      // |  x|
      // +---+
      // |  1|
      // |  2|
      // +---+
    • with raw SQL:

      sqlContext.sql("SELECT a_struct.x FROM df").show
      // +---+
      // |  x|
      // +---+
      // |  1|
      // |  2|
      // +---+
  • fields inside an array of structs can be accessed using dot syntax, names, and standard Column methods:

    df.select($"an_array_of_structs.foo").show
    // +----------+
    // |       foo|
    // +----------+
    // |[foo, bar]|
    // |[foz, baz]|
    // +----------+

    sqlContext.sql("SELECT an_array_of_structs[0].foo FROM df").show
    // +---+
    // |_c0|
    // +---+
    // |foo|
    // |foz|
    // +---+

    df.select($"an_array_of_structs.vals".getItem(1).getItem(1)).show
    // +------------------------------+
    // |an_array_of_structs.vals[1][1]|
    // +------------------------------+
    // |                           4.0|
    // |                           8.0|
    // +------------------------------+
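    The behaviour of dot syntax on an array of structs can be modelled in plain Scala: selecting a field through the array maps over the elements, while indexing picks one element first. A sketch using local case classes that mirror the first record above (an analogy, not Spark code):

    ```scala
    // Plain-Scala model of field access on an array of structs
    // (an analogy, not Spark code): local case classes mirror the schema.
    case class ArrayElement(foo: String, bar: Int, vals: Seq[Double])

    // an_array_of_structs of the first record
    val anArrayOfStructs = Seq(
      ArrayElement("foo", 1, Seq(1.0, 2.0, 2.0)),
      ArrayElement("bar", 2, Seq(3.0, 4.0, 5.0))
    )

    // an_array_of_structs.foo: the field is taken from every element
    val foos = anArrayOfStructs.map(_.foo)

    // an_array_of_structs[0].foo: index first (0-based), then take the field
    val firstFoo = anArrayOfStructs(0).foo

    // an_array_of_structs.vals[1][1]: an array of arrays, indexed twice
    val allVals = anArrayOfStructs.map(_.vals)
    val nested = allVals(1)(1)
    ```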
  • fields of user-defined types (UDTs) can be accessed using UDFs. See "Spark SQL referencing attributes of UDT" for details.


  • depending on the Spark version, some of these methods may be available only with HiveContext. UDFs should work independently of the version with both the standard SQLContext and HiveContext.
  • generally speaking, nested values are second-class citizens: not all typical operations are supported on nested fields. Depending on the context, it may be better to flatten the schema and/or explode collections:

    import org.apache.spark.sql.functions.explode

    df.select(explode($"an_array_of_structs")).show
    // +--------------------+
    // |                 col|
    // +--------------------+
    // |[foo,1,WrappedArr...|
    // |[bar,2,WrappedArr...|
    // |[foz,3,WrappedArr...|
    // |[baz,4,WrappedArr...|
    // +--------------------+
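    `explode` produces one output row per array element, so the two records above, with two structs each, become four rows. Here is a plain-Scala sketch of that behaviour using `flatMap` over local case classes (an analogy, not Spark code). Like `flatMap`, `explode` emits no rows for an empty or null array; `explode_outer` (2.2+) keeps such rows with a null value.

    ```scala
    // Plain-Scala model of explode (an analogy, not Spark code):
    // each array element becomes its own row.
    case class ArrayElement(foo: String, bar: Int, vals: Seq[Double])
    case class Record(anArrayOfStructs: Seq[ArrayElement])

    val records = Seq(
      Record(Seq(ArrayElement("foo", 1, Seq(1.0, 2.0, 2.0)),
                 ArrayElement("bar", 2, Seq(3.0, 4.0, 5.0)))),
      Record(Seq(ArrayElement("foz", 3, Seq(5.0, 6.0)),
                 ArrayElement("baz", 4, Seq(7.0, 8.0))))
    )

    // explode($"an_array_of_structs") ~ flatMap over the array column
    val exploded = records.flatMap(_.anArrayOfStructs)

    // After exploding, struct fields are plain per-row values again
    val foos = exploded.map(_.foo)

    // Like flatMap, explode emits no rows for an empty array
    val emptyCase = Seq(Record(Seq.empty)).flatMap(_.anArrayOfStructs)
    ```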
  • dot syntax can be combined with the wildcard character (*) to select (possibly multiple) fields without specifying names explicitly:

    df.select($"a_struct.*").show
    // +---+
    // |  x|
    // +---+
    // |  1|
    // |  2|
    // +---+
  • JSON columns can be queried using the get_json_object and from_json functions. See "How to query JSON data column using Spark DataFrames?" for details.

