I have a DataFrame containing several columns I'd like to use as input to a function which will produce multiple outputs per row, with each output going into a new column.
For example, I have a function that takes address values and parses them into finer-grained parts:
def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    ...
Example output:
{'STREETNUMPREFIX': None,
'STREETNUMBER': '123',
'STREETNUMSUFFIX': None,
'STREETNAME': 'Elm',
'STREETTYPE': 'Ave.'}
So let's say I have a DataFrame with columns address1, city, and state, and I would like to apply the above parser function across all rows, using the values of these three columns as input and storing the output for each row as new columns matching the keys of the returned dictionary.
Here is what I have tried so far:
from typing import Dict, List

from pyspark.sql import functions as F
from pyspark.sql.types import Row, StringType, StructField, StructType
import usaddress


def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    unstructured_address = " ".join((address1, city, state))
    return parse_unstructured_address(unstructured_address)


def parse_unstructured_address(address: str) -> Dict[str, str]:
    tags = usaddress.tag(address_string=address)
    return {
        "STREETNUMPREFIX": tags[0].get("AddressNumberPrefix", None),
        "STREETNUMBER": tags[0].get("AddressNumber", None),
        "STREETNUMUNIT": tags[0].get("OccupancyIdentifier", None),
        "STREETNUMSUFFIX": tags[0].get("AddressNumberSuffix", None),
        "PREDIRECTIONAL": tags[0].get("StreetNamePreDirectional", None),
        "STREETNAME": tags[0].get("StreetName", None),
        "STREETTYPE": tags[0].get("StreetNamePostType", None),
        "POSTDIRECTIONAL": tags[0].get("StreetNamePostDirectional", None),
    }


def parse_func(address: str, city: str, state: str) -> Row:
    # Build a Row from the parsed dict so the UDF can return a struct value.
    address_parts = parser(address1=address, city=city, state=state)
    return Row(*address_parts.keys())(*address_parts.values())


def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])


input_columns = ["Address1", "CITY", "STATE"]
df = spark.createDataFrame([("123 Main St.", "Cleveland", "OH"), ("57 Heinz St.", "Columbus", "OH")], input_columns)

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]
out_columns = input_columns + parsed_columns
output_schema = get_schema(out_columns)

parse_udf = F.udf(parse_func, output_schema)
df = df.withColumn("Output", F.explode(F.array(parse_udf(df["Address1"], df["CITY"], df["STATE"]))))
display(df)
The above has only resulted in strange null pointer exceptions that tell me nothing about how to fix things:
SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 4 times, most recent failure: Lost task 0.3 in stage 9.0 (TID 86, 172.18.237.92, executor 1): java.lang.NullPointerException
at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_1$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$11(EvalPythonExec.scala:134)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:731)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:187)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
at org.apache.spark.scheduler.Task.run(Task.scala:117)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$9(Executor.scala:657)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Here is a really simple example of UDF usage.
from typing import List
from pyspark.sql.functions import *
from pyspark.sql.types import *

def cal(a: int, b: int) -> List[int]:
    return [a + b, a * b]

cal = udf(cal, ArrayType(StringType()))

df.select('A', 'B', *[cal('A', 'B')[i] for i in range(0, 2)]) \
  .toDF('A', 'B', 'Add', 'Multiple').show()
+---+---+---+--------+
| A| B|Add|Multiple|
+---+---+---+--------+
| 1| 2| 3| 2|
| 2| 4| 6| 8|
| 3| 6| 9| 18|
+---+---+---+--------+
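If indexing into the array by position feels brittle, the same idea can be written with a struct return type so each output keeps its own field name. This is only a sketch with names I made up (cal_struct, result_schema), against the same A/B DataFrame as above:

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StructField, StructType

result_schema = StructType([
    StructField("Add", IntegerType(), True),       # nullable=True to be safe
    StructField("Multiple", IntegerType(), True),
])

@F.udf(returnType=result_schema)
def cal_struct(a: int, b: int):
    return (a + b, a * b)

# "Out.*" expands the struct into the Add and Multiple columns in one go.
df.select("A", "B", cal_struct("A", "B").alias("Out")) \
  .select("A", "B", "Out.*") \
  .show()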
I have checked your code and found the problem here.
def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])
You did not allow null values for any of the columns, but the parsed output does contain nulls, which is where the error comes from. I'd recommend changing the nullable flag from False to True; then it will work.
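For reference, a sketch of the suggested change (only the nullable flag differs from your version):

def get_schema(columns: List[str]) -> StructType:
    # nullable=True: parsed parts such as STREETNUMPREFIX can legitimately be None
    return StructType([StructField(col_name, StringType(), True) for col_name in columns])

Once the UDF runs without the exception, the fields of the Output struct column can be pulled out into their own top-level columns with something like df.select("*", "Output.*").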