All the data types in pyspark.sql.types are:

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType",
    "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType",
    "ByteType", "IntegerType", "LongType", "ShortType",
    "ArrayType", "MapType", "StructField", "StructType"]
I have to write a UDF (in pyspark) which returns an array of tuples. What do I give as the second argument to the udf method, which is its return type? It would be something along the lines of ArrayType(TupleType()).
...
A PySpark UDF is a User Defined Function that is used to create a reusable function in Spark. Once a UDF is created, it can be re-used on multiple DataFrames and in SQL (after registering). The default return type of udf() is StringType. You need to handle nulls explicitly, otherwise you will see side-effects.
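As a hedged sketch of that flow (assuming a SQLContext named sqlContext, as in the answers further down; the column names and the to_upper function are illustrative):

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Handle nulls explicitly inside the function body.
to_upper = udf(lambda s: s.upper() if s is not None else None, StringType())

df = sqlContext.createDataFrame([(1, "foo"), (2, None)], ["id", "value"])
df.select("id", to_upper("value").alias("value_upper")).show()

# Register the same function so it can be called from SQL as well.
sqlContext.registerFunction("to_upper", lambda s: s.upper() if s is not None else None, StringType())
df.registerTempTable("t")
sqlContext.sql("SELECT id, to_upper(value) AS value_upper FROM t").show()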
The StructType in PySpark is defined as a collection of StructFields; each StructField defines a column name, a column data type, a boolean specifying whether the field can be nullable, and optional metadata. A StructField in PySpark represents a single field in a StructType.
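A small illustrative sketch of a StructType used as an explicit DataFrame schema (the name/age fields are made up for the example):

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

person_schema = StructType([
    StructField("name", StringType(), False),   # column name, data type, nullable
    StructField("age", IntegerType(), True),
])

people = sqlContext.createDataFrame([("Alice", 30), ("Bob", None)], person_schema)
people.printSchema()
# root
#  |-- name: string (nullable = false)
#  |-- age: integer (nullable = true)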
Create PySpark ArrayType: you can create an instance of an ArrayType using the ArrayType() class. It takes an elementType and one optional argument, containsNull, which specifies whether elements can be null and defaults to True. elementType should be a PySpark type that extends the DataType class.
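For example, a minimal sketch of constructing an ArrayType:

from pyspark.sql.types import ArrayType, StringType

# Elements may be null by default (containsNull=True).
tags_type = ArrayType(StringType())
# Elements may not be null.
strict_tags_type = ArrayType(StringType(), False)

print(tags_type.simpleString())       # array<string>
print(strict_tags_type.containsNull)  # False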
pyspark.sql.functions.explode(col): Returns a new row for each element in the given array or map. Uses the default column name col for elements in the array, and key and value for elements in the map, unless specified otherwise.
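A quick sketch of explode on an array column (column names are illustrative):

from pyspark.sql.functions import explode

df = sqlContext.createDataFrame([(1, ["a", "b"]), (2, ["c"])], ["id", "letters"])
# One output row per array element: (1, a), (1, b), (2, c).
df.select("id", explode("letters").alias("letter")).show()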
There is no such thing as a TupleType in Spark. Product types are represented as structs with fields of a specific type. For example, if you want to return an array of pairs (integer, string), you can use a schema like this:
from pyspark.sql.types import *

schema = ArrayType(StructType([
    StructField("char", StringType(), False),
    StructField("count", IntegerType(), False)
]))
Example usage:
from pyspark.sql.functions import udf
from collections import Counter

char_count_udf = udf(
    lambda s: Counter(s).most_common(),
    schema
)

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])

df.select("*", char_count_udf(df["value"])).show(2, False)

## +---+-----+-------------------------+
## |id |value|PythonUDF#<lambda>(value)|
## +---+-----+-------------------------+
## |1  |foo  |[[o,2], [f,1]]           |
## |2  |bar  |[[r,1], [a,1], [b,1]]    |
## +---+-----+-------------------------+
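If you then need to work with the returned array of structs, you can explode it and address the struct fields by name; a small follow-up sketch along those lines (the counts/pair aliases are illustrative, not part of the original answer):

from pyspark.sql.functions import explode

counts = df.select("id", char_count_udf(df["value"]).alias("counts"))
# One row per (char, count) struct; fields are accessed with dot notation.
counts.select("id", explode("counts").alias("pair")) \
      .select("id", "pair.char", "pair.count") \
      .show()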
Stack Overflow keeps directing me to this question, so I guess I'll add some info here.
Returning simple types from UDF:
from pyspark.sql.types import *
from pyspark.sql.functions import udf
from pyspark.sql import functions as F

def get_df():
    d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
    df = sqlContext.createDataFrame(d, ['x', 'y'])
    return df

df = get_df()
df.show()
# +---+---+
# |  x|  y|
# +---+---+
# |0.0|0.0|
# |0.0|3.0|
# |1.0|6.0|
# |1.0|9.0|
# +---+---+

func = udf(lambda x: str(x), StringType())
df = df.withColumn('y_str', func('y'))

func = udf(lambda x: int(x), IntegerType())
df = df.withColumn('y_int', func('y'))

df.show()
# +---+---+-----+-----+
# |  x|  y|y_str|y_int|
# +---+---+-----+-----+
# |0.0|0.0|  0.0|    0|
# |0.0|3.0|  3.0|    3|
# |1.0|6.0|  6.0|    6|
# |1.0|9.0|  9.0|    9|
# +---+---+-----+-----+

df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- y_str: string (nullable = true)
#  |-- y_int: integer (nullable = true)
When integers are not enough:
df = get_df()

func = udf(lambda x: [0] * int(x), ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))

func = udf(lambda x: {float(y): str(y) for y in range(int(x))},
           MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))

df.show()
# +---+---+--------------------+--------------------+
# |  x|  y|                list|                 map|
# +---+---+--------------------+--------------------+
# |0.0|0.0|                  []|               Map()|
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+

df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- list: array (nullable = true)
#  |    |-- element: integer (containsNull = true)
#  |-- map: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: string (valueContainsNull = true)
Returning complex datatypes from UDF:
df = get_df()
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
df.show()
# +---+----------+
# |  x|       y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+

schema = StructType([
    StructField("min", FloatType(), True),
    StructField("size", IntegerType(), True),
    StructField("edges", ArrayType(FloatType()), True),
    StructField("val_to_index", MapType(FloatType(), IntegerType()), True)
    # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))
])

def func(values):
    mn = min(values)
    size = len(values)
    lst = sorted(values)[::-1]
    val_to_index = {x: i for i, x in enumerate(values)}
    return (mn, size, lst, val_to_index)

func = udf(func, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)
# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+

dff.printSchema()
Passing multiple arguments to a UDF:
df = get_df()

func = udf(lambda arr: arr[0] * arr[1], FloatType())
df = df.withColumn('x*y', func(F.array('x', 'y')))
df.show()
# +---+---+---+
# |  x|  y|x*y|
# +---+---+---+
# |0.0|0.0|0.0|
# |0.0|3.0|0.0|
# |1.0|6.0|6.0|
# |1.0|9.0|9.0|
# +---+---+---+
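Note that a Python UDF can also take multiple columns directly, without packing them into an array first; a small sketch of that variant (same data, same result):

df = get_df()
func = udf(lambda x, y: x * y, FloatType())
df = df.withColumn('x*y', func('x', 'y'))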
The code is purely for demo purposes; all of the above transformations are available as built-in Spark functions and would yield much better performance. As @zero323 notes in the comments above, UDFs should generally be avoided in PySpark; having to return complex types should make you think about simplifying your logic.
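To illustrate that last point, a hedged sketch of the x*y example above rewritten with a built-in column expression instead of a UDF:

df = get_df()
# Same result as the 'x*y' UDF above, without the Python UDF overhead.
df = df.withColumn('x*y', (F.col('x') * F.col('y')).cast('float'))
df.show()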