
How to unwrap nested Struct column into multiple columns?

I'm trying to expand a DataFrame column with nested struct type (see below) to multiple columns. The Struct schema I'm working with looks something like {"foo": 3, "bar": {"baz": 2}}.

Ideally, I'd like to expand the above into two columns ("foo" and "bar.baz"). However, when I tried using .select("data.*") (where data is the Struct column), I only get columns foo and bar, where bar is still a struct.

Is there a way such that I can expand the Struct for both layers?

asked Oct 24 '17 by Zz'Rot

People also ask

How do you flatten a struct in SQL?

SELECT * FROM a_join_b; flattens the structs and produces a table with fields named a_field1, a_field2, ..., b_field1, b_field2. Note the underscores between the table names and the field names; a and b may have overlapping field names.

How do you select nested columns in PySpark?

In PySpark, the select() function is used to select a single column, multiple columns, columns by index, all columns from a list, or nested columns from a DataFrame. select() is a transformation, so it returns a new DataFrame containing the selected columns.


2 Answers

You can select data.bar.baz as bar.baz:

df.show()
+-------+
|   data|
+-------+
|[3,[2]]|
+-------+

df.printSchema()
root
 |-- data: struct (nullable = false)
 |    |-- foo: long (nullable = true)
 |    |-- bar: struct (nullable = false)
 |    |    |-- baz: long (nullable = true)

In PySpark:

import pyspark.sql.functions as F
df.select(F.col("data.foo").alias("foo"), F.col("data.bar.baz").alias("bar.baz")).show()
+---+-------+
|foo|bar.baz|
+---+-------+
|  3|      2|
+---+-------+
answered Oct 17 '22 by Psidom


I ended up going with the following function, which recursively "unwraps" nested Structs.

Essentially, it keeps digging into Struct fields while leaving the other fields intact. This approach avoids a very long df.select(...) statement when the Struct has many fields. Here's the code:

from pyspark.sql.functions import col
from pyspark.sql.types import StructType

# Takes a StructType schema object and returns a list of column selectors
# that flattens the Struct, one nesting level per recursive call
def flatten_struct(schema, prefix=""):
    result = []
    for elem in schema:
        if isinstance(elem.dataType, StructType):
            # Nested struct: recurse, extending the dotted prefix
            result += flatten_struct(elem.dataType, prefix + elem.name + ".")
        else:
            # Leaf field: select it and keep the full dotted path as its alias
            result.append(col(prefix + elem.name).alias(prefix + elem.name))
    return result


from pyspark.sql import Row

df = sc.parallelize([Row(r=Row(a=1, b=Row(foo="b", bar="12")))]).toDF()
df.show()
+----------+
|         r|
+----------+
|[1,[12,b]]|
+----------+

df_expanded = df.select("r.*")
df_flattened = df_expanded.select(flatten_struct(df_expanded.schema))

df_flattened.show()
+---+-----+-----+
|  a|b.bar|b.foo|
+---+-----+-----+
|  1|   12|    b|
+---+-----+-----+
answered Oct 17 '22 by Zz'Rot