I'm trying to expand a DataFrame column with nested struct type (see below) to multiple columns. The Struct schema I'm working with looks something like {"foo": 3, "bar": {"baz": 2}}.
Ideally, I'd like to expand the above into two columns ("foo" and "bar.baz"). However, when I tried using .select("data.*") (where data is the Struct column), I only got columns foo and bar, where bar is still a struct.
Is there a way to expand the Struct for both layers?
In Spark SQL, SELECT * FROM a_join_b will flatten the structs and produce a table with fields named a_field1, a_field2, ..., b_field1, b_field2. Note the underscores between the table names and the field names, and that a and b can have similar field names.
In PySpark, the select() function is used to select a single column, multiple columns, columns by index, all columns, or nested columns from a DataFrame. select() is a transformation, so it returns a new DataFrame containing the selected columns.
You can select data.bar.baz as bar.baz:
df.show()
+-------+
| data|
+-------+
|[3,[2]]|
+-------+
df.printSchema()
root
|-- data: struct (nullable = false)
| |-- foo: long (nullable = true)
| |-- bar: struct (nullable = false)
| | |-- baz: long (nullable = true)
In pyspark:
import pyspark.sql.functions as F
df.select(F.col("data.foo").alias("foo"), F.col("data.bar.baz").alias("bar.baz")).show()
+---+-------+
|foo|bar.baz|
+---+-------+
| 3| 2|
+---+-------+
I ended up going for the following function that recursively "unwraps" nested Structs. Essentially, it keeps digging into Struct fields and leaves the other fields intact; this approach eliminates the need for a very long df.select(...) statement when the Struct has many fields. Here's the code:
from pyspark.sql import Row
from pyspark.sql.functions import col
from pyspark.sql.types import StructType

# Takes in a StructType schema object and returns a list of column
# selectors that flatten the Struct
def flatten_struct(schema, prefix=""):
    result = []
    for elem in schema:
        if isinstance(elem.dataType, StructType):
            result += flatten_struct(elem.dataType, prefix + elem.name + ".")
        else:
            result.append(col(prefix + elem.name).alias(prefix + elem.name))
    return result

df = sc.parallelize([Row(r=Row(a=1, b=Row(foo="b", bar="12")))]).toDF()
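The recursion itself doesn't depend on Spark at all; the same shape shows up when flattening a nested dict into dotted keys. A pure-Python analogue (illustration only, not part of the original answer):

```python
# Pure-Python analogue of the same recursion: walk a nested dict and
# collect leaf values under dotted key paths.
def flatten_dict(d, prefix=""):
    result = {}
    for key, value in d.items():
        if isinstance(value, dict):
            result.update(flatten_dict(value, prefix + key + "."))
        else:
            result[prefix + key] = value
    return result

flatten_dict({"foo": 3, "bar": {"baz": 2}})
# → {"foo": 3, "bar.baz": 2}
```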
df.show()
+----------+
| r|
+----------+
|[1,[12,b]]|
+----------+
df_expanded = df.select("r.*")
df_flattened = df_expanded.select(flatten_struct(df_expanded.schema))
df_flattened.show()
+---+-----+-----+
| a|b.bar|b.foo|
+---+-----+-----+
| 1| 12| b|
+---+-----+-----+