
PySpark: How to apply UDF to multiple columns to create multiple new columns?

I have a DataFrame containing several columns I'd like to use as input to a function which will produce multiple outputs per row, with each output going into a new column.

For example, I have a function that takes address values and parses them into finer-grained parts:

def parser(address1: str, city: str, state: str) -> Dict[str, str]: 
    ...

Example output:

{'STREETNUMPREFIX': None,
 'STREETNUMBER': '123',
 'STREETNUMSUFFIX': None,
 'STREETNAME': 'Elm',
 'STREETTYPE': 'Ave.'}

So let's say I have a DataFrame with columns address1, city, and state. I would like to apply the above parser function to every row, using the values of these three columns as input, and store each row's output in new columns matching the keys of the returned dictionary.

Here is what I have tried so far:

from typing import Dict, List

from pyspark.sql import functions as F
from pyspark.sql.types import Row, StringType, StructField, StructType
import usaddress

def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    unstructured_address = " ".join((address1, city, state))
    return parse_unstructured_address(unstructured_address)


def parse_unstructured_address(address: str) -> Dict[str, str]:
    tags = usaddress.tag(address_string=address)
    return {
        "STREETNUMPREFIX": tags[0].get("AddressNumberPrefix", None),
        "STREETNUMBER": tags[0].get("AddressNumber", None),
        "STREETNUMUNIT": tags[0].get("OccupancyIdentifier", None),
        "STREETNUMSUFFIX": tags[0].get("AddressNumberSuffix", None),
        "PREDIRECTIONAL": tags[0].get("StreetNamePreDirectional", None),
        "STREETNAME": tags[0].get("StreetName", None),
        "STREETTYPE": tags[0].get("StreetNamePostType", None),
        "POSTDIRECTIONAL": tags[0].get("StreetNamePostDirectional", None),
    }

def parse_func(address: str, city: str, state: str) -> Row:
    address_parts = parser(address1=address, city=city, state=state)
    return Row(*address_parts.keys())(*address_parts.values())

def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

input_columns = ["Address1", "CITY", "STATE"]
df = spark.createDataFrame([("123 Main St.", "Cleveland", "OH"), ("57 Heinz St.", "Columbus", "OH")], input_columns)

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]
out_columns = input_columns + parsed_columns
output_schema = get_schema(out_columns)

parse_udf = F.udf(parse_func, output_schema)

df = df.withColumn("Output", F.explode(F.array(parse_udf(df["Address1"], df["CITY"], df["STATE"]))))
display(df)

The above has only resulted in strange null pointer exceptions that tell me nothing about how to fix things:

SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 4 times, most recent failure: Lost task 0.3 in stage 9.0 (TID 86, 172.18.237.92, executor 1): java.lang.NullPointerException
    at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_1$(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$11(EvalPythonExec.scala:134)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:731)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:187)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
    at org.apache.spark.scheduler.Task.run(Task.scala:117)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$9(Executor.scala:657)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
asked Oct 17 '25 by James Adams

1 Answer

Here is my really simple example of UDF usage, returning multiple values from one UDF:

from typing import List

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

def cal(a: int, b: int) -> List[int]:
    return [a + b, a * b]

# Declare the element type to match the ints the function returns,
# and bind the UDF under a new name so the Python function is not shadowed.
cal_udf = udf(cal, ArrayType(IntegerType()))

df.select('A', 'B', *[cal_udf('A', 'B')[i] for i in range(0, 2)]) \
  .toDF('A', 'B', 'Add', 'Multiple').show()

+---+---+---+--------+
|  A|  B|Add|Multiple|
+---+---+---+--------+
|  1|  2|  3|       2|
|  2|  4|  6|       8|
|  3|  6|  9|      18|
+---+---+---+--------+
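
If you want named fields rather than positional indexes, the same idea works with a struct return type. A minimal sketch, assuming the same toy DataFrame as above (the calc name and field names are hypothetical):

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StructField, StructType

# Hypothetical schema; a returned tuple maps positionally onto these fields.
calc_schema = StructType([
    StructField("Add", IntegerType(), True),
    StructField("Multiple", IntegerType(), True),
])

@udf(returnType=calc_schema)
def calc(a: int, b: int):
    return (a + b, a * b)

df.select('A', 'B', calc('A', 'B').alias('calc')) \
  .select('A', 'B', 'calc.Add', 'calc.Multiple').show()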

I checked your code and found this:

def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

You did not allow null values in any column, but the parser returns None for some fields, so that is where the NullPointerException comes from. I'd recommend changing False to True for the nullable argument, and then it will work.
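
Applied to the question's code, that is a one-keyword change. A minimal sketch of the corrected schema, reusing parse_func and the names from the question (it assumes parsed_columns lists exactly the keys that parse_func returns, in order):

from typing import List

from pyspark.sql import functions as F
from pyspark.sql.types import StringType, StructField, StructType

def get_schema(columns: List[str]) -> StructType:
    # nullable=True, since the parser returns None for parts it cannot find
    return StructType([StructField(col_name, StringType(), True) for col_name in columns])

parse_udf = F.udf(parse_func, get_schema(parsed_columns))

# Attach the struct column, then expand its fields into top-level columns;
# no explode/array wrapper is needed.
df = df.withColumn("parsed", parse_udf("Address1", "CITY", "STATE"))
df = df.select("Address1", "CITY", "STATE", "parsed.*")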

answered Oct 19 '25 by Lamanus