PySpark explode stringified array of dictionaries into rows

I have a PySpark DataFrame with a StringType column (edges) that holds a stringified (JSON) list of dictionaries (see the example below). The dictionaries contain a mix of value types, including another nested dictionary (nodeIDs). I need to explode the top-level dictionaries in the edges field into rows; ideally, I should then be able to turn their component values into separate columns.

Input:

import findspark
findspark.init()

from pyspark.sql import Row, SparkSession

SPARK = SparkSession.builder.enableHiveSupport() \
                    .getOrCreate()

data = [
    Row(trace_uuid='aaaa', timestamp='2019-05-20T10:36:33+02:00', edges='[{"distance":4.382441320292239,"duration":1.5,"speed":2.9,"nodeIDs":{"nodeA":954752475,"nodeB":1665827480}},{"distance":14.48582171131768,"duration":2.6,"speed":5.6,"nodeIDs":{"nodeA":1665827480,"nodeB":3559056131}}]', count=156, level=36),
    Row(trace_uuid='bbbb', timestamp='2019-05-20T11:36:10+03:00', edges='[{"distance":0,"duration":0,"speed":0,"nodeIDs":{"nodeA":520686131,"nodeB":520686216}},{"distance":8.654358326561642,"duration":3.1,"speed":2.8,"nodeIDs":{"nodeA":520686216,"nodeB":506361795}}]', count=179, level=258)
    ]

df = SPARK.createDataFrame(data)

Desired output:

    data_reshaped = [
        Row(trace_uuid='aaaa', timestamp='2019-05-20T10:36:33+02:00', distance=4.382441320292239, duration=1.5, speed=2.9, nodeA=954752475, nodeB=1665827480, count=156, level=36),
        Row(trace_uuid='aaaa', timestamp='2019-05-20T10:36:33+02:00', distance=14.48582171131768, duration=2.6, speed=5.6, nodeA=1665827480, nodeB=3559056131, count=156, level=36),
        Row(trace_uuid='bbbb', timestamp='2019-05-20T11:36:10+03:00', distance=0, duration=0, speed=0, nodeA=520686131, nodeB=520686216, count=179, level=258),
        Row(trace_uuid='bbbb', timestamp='2019-05-20T11:36:10+03:00', distance=8.654358326561642, duration=3.1, speed=2.8, nodeA=520686216, nodeB=506361795, count=179, level=258)
       ]

Is there a way to do that? I've tried using cast to convert the edges field to an array first, but I can't figure out how to make it work with the mixed data types.

I'm using Spark 2.4.0.

asked Jun 14 '19 at 00:06 by SoHei
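For reference, the per-record transformation the question asks for can be written in plain Python with the standard json module (no Spark involved), which makes the target shape explicit: parse the edges string, emit one flat dict per edge, copy the record's other fields onto each one, and lift nodeA/nodeB out of the nested nodeIDs dict. This is only an illustrative sketch; reshape_record is a hypothetical helper name, and a local loop like this does not replace the distributed Spark approach in the answer.

```python
import json


def reshape_record(record):
    """Hypothetical helper: turn one input record into one flat dict per edge."""
    # every field except the raw JSON string is copied onto each output row
    base = {k: v for k, v in record.items() if k != "edges"}
    rows = []
    for edge in json.loads(record["edges"]):
        flat = dict(base)
        # copy the scalar edge fields (distance, duration, speed)
        flat.update({k: v for k, v in edge.items() if k != "nodeIDs"})
        # lift the nested nodeIDs dict into top-level nodeA / nodeB keys
        flat.update(edge["nodeIDs"])
        rows.append(flat)
    return rows


record = {
    "trace_uuid": "aaaa",
    "timestamp": "2019-05-20T10:36:33+02:00",
    "edges": '[{"distance":4.382441320292239,"duration":1.5,"speed":2.9,"nodeIDs":{"nodeA":954752475,"nodeB":1665827480}},'
             '{"distance":14.48582171131768,"duration":2.6,"speed":5.6,"nodeIDs":{"nodeA":1665827480,"nodeB":3559056131}}]',
    "count": 156,
    "level": 36,
}

for row in reshape_record(record):
    print(row)
```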




1 Answer

You can use schema_of_json() to infer the JSON schema from a sample string, then pass that schema to from_json() to parse the column. For example:

from pyspark.sql import functions as F

# a sample JSON string:
edges_json_sample = data[0].edges
# or: edges_json_sample = df.select('edges').first()[0]

print(edges_json_sample)
# [{"distance":4.382441320292239,"duration":1.5,"speed":2.9,"nodeIDs":{"nodeA":954752475,"nodeB":1665827480}},{"distance":14.48582171131768,"duration":2.6,"speed":5.6,"nodeIDs":{"nodeA":1665827480,"nodeB":3559056131}}]

# infer the schema (as a DDL string) from the sample string
schema = df.select(F.schema_of_json(edges_json_sample)).first()[0]

print(schema)
# array<struct<distance:double,duration:double,nodeIDs:struct<nodeA:bigint,nodeB:bigint>,speed:double>>

# parse the JSON string into an array of structs, explode it into one row per edge,
# then promote the struct fields (and the nested nodeIDs fields) to top-level columns
new_df = df.withColumn('data', F.explode(F.from_json('edges', schema))) \
           .select('*', 'data.*', 'data.nodeIDs.*') \
           .drop('data', 'nodeIDs', 'edges')

new_df.show()
# +-----+-----+--------------------+----------+-----------------+--------+-----+----------+----------+
# |count|level|           timestamp|trace_uuid|         distance|duration|speed|     nodeA|     nodeB|
# +-----+-----+--------------------+----------+-----------------+--------+-----+----------+----------+
# |  156|   36|2019-05-20T10:36:...|      aaaa|4.382441320292239|     1.5|  2.9| 954752475|1665827480|
# |  156|   36|2019-05-20T10:36:...|      aaaa|14.48582171131768|     2.6|  5.6|1665827480|3559056131|
# |  179|  258|2019-05-20T11:36:...|      bbbb|              0.0|     0.0|  0.0| 520686131| 520686216|
# |  179|  258|2019-05-20T11:36:...|      bbbb|8.654358326561642|     3.1|  2.8| 520686216| 506361795|
# +-----+-----+--------------------+----------+-----------------+--------+-----+----------+----------+

# expected result: collect the rows as a list of Row objects
data_reshaped = new_df.collect()
answered Oct 23 '22 at 01:10 by jxc