I am trying to 'lift' the fields of a struct to the top level in a dataframe, as illustrated by this example:
case class A(a1: String, a2: String)
case class B(b1: String, b2: A)
val df = Seq(B("X",A("Y","Z"))).toDF
df.show
+---+-----+
| b1|   b2|
+---+-----+
|  X|[Y,Z]|
+---+-----+
df.printSchema
root
 |-- b1: string (nullable = true)
 |-- b2: struct (nullable = true)
 |    |-- a1: string (nullable = true)
 |    |-- a2: string (nullable = true)
val lifted = df.withColumn("a1", $"b2.a1").withColumn("a2", $"b2.a2").drop("b2")
lifted.show
+---+---+---+
| b1| a1| a2|
+---+---+---+
|  X|  Y|  Z|
+---+---+---+
lifted.printSchema
root
 |-- b1: string (nullable = true)
 |-- a1: string (nullable = true)
 |-- a2: string (nullable = true)
This works. I would like to create a small utility method that does this for me, probably by enriching DataFrame with an implicit class so that I can write something like df.lift("b2").
To do this, I think I want a way of obtaining a list of all fields within a Struct. E.g. given "b2" as input, return ["a1","a2"]. How do I do this?
If I understand your question correctly, you want to be able to list the nested fields of column b2.
To do that, select the b2 entry from the schema, cast its data type to StructType, and map each of its fields (StructField) to its name:
import org.apache.spark.sql.types.StructType
val nested_fields = df.schema
  .filter(c => c.name == "b2")
  .flatMap(_.dataType.asInstanceOf[StructType].fields)
  .map(_.name)
// nested_fields: Seq[String] = List(a1, a2)
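Actually, you can use ".fieldNames.toList" on the StructType of b2:
// df.schema("b2") is a StructField, so go through its dataType to reach the StructType
val nested_fields = df.schema("b2").dataType.asInstanceOf[StructType].fieldNames.toList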
It returns a list of strings. If you want a list of columns instead, map each name to a Column.
I hope it helps.
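Putting the pieces together, here is a minimal sketch of the df.lift("b2") utility the question asks for. The names DataFrameLiftSyntax, LiftOps and lift are hypothetical and not part of the Spark API. The sketch assumes that column names contain no dots or backticks and that no nested field shares a name with an existing top-level column.

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

// Hypothetical helper object; import DataFrameLiftSyntax._ to enable df.lift(...)
object DataFrameLiftSyntax {

  implicit class LiftOps(private val df: DataFrame) extends AnyVal {

    // Replaces the struct column `structCol` with one top-level column per nested field.
    def lift(structCol: String): DataFrame =
      df.schema(structCol).dataType match {
        case st: StructType =>
          // e.g. col("b2.a1").as("a1"), col("b2.a2").as("a2")
          val nested: Seq[Column] =
            st.fieldNames.toSeq.map(f => col(s"$structCol.$f").as(f))
          // All other columns are kept, in their original order.
          // Assumption: names without dots or backticks, so col(name) resolves them directly.
          val kept: Seq[Column] =
            df.columns.toSeq.filter(_ != structCol).map(c => col(c))
          df.select(kept ++ nested: _*)
        case other =>
          throw new IllegalArgumentException(
            s"Column '$structCol' has type $other, which is not a struct")
      }
  }
}
```

With import DataFrameLiftSyntax._ in scope, df.lift("b2") on the DataFrame from the question should produce the columns b1, a1, a2. If the helper is not needed, df.select("b1", "b2.*") expands the struct in a single call, since Spark supports star expansion on struct columns.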