TypeError: Column is not iterable - How to iterate over ArrayType()?

Consider the following DataFrame:

+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[john, sam, jane]      |
|pet   |[whiskers, rover, fido]|
+------+-----------------------+

Which can be created with the following code:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [
    ('person', ['john', 'sam', 'jane']),
    ('pet', ['whiskers', 'rover', 'fido'])
]

df = spark.createDataFrame(data, ["type", "names"])
df.show(truncate=False)

Is there a way to directly modify the ArrayType() column "names" by applying a function to each element, without using a udf?

For example, suppose I wanted to apply the function foo to the "names" column. (I will use the example where foo is str.upper just for illustrative purposes, but my question is regarding any valid function that can be applied to the elements of an iterable.)

foo = lambda x: x.upper()  # defining it as str.upper as an example
df.withColumn('X', [foo(x) for x in f.col("names")]).show()

TypeError: Column is not iterable
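The error comes from the list comprehension itself. Python calls `iter()` on `f.col("names")`, but a `Column` is an unevaluated expression that Spark resolves later for each row, not a container holding the array values, and PySpark's `Column` defines `__iter__` to raise a `TypeError`. The toy stand-in below reproduces that mechanism in plain Python; `ToyColumn` and `try_comprehension` are hypothetical names, not part of PySpark:

```python
class ToyColumn:
    """Hypothetical stand-in for pyspark.sql.Column: it only records an expression."""

    def __init__(self, expr):
        self.expr = expr  # e.g. a column name; no row data is stored here

    def __iter__(self):
        # Mirrors PySpark: a Column cannot be iterated on the driver
        raise TypeError("Column is not iterable")


def try_comprehension(col):
    foo = lambda x: x.upper()
    try:
        return [foo(x) for x in col]  # iter(col) is called here
    except TypeError as exc:
        return str(exc)


print(try_comprehension(ToyColumn("names")))  # Column is not iterable
print(try_comprehension(["john", "sam"]))     # ['JOHN', 'SAM']
```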

I could do this using a udf:

from pyspark.sql.types import ArrayType, StringType

foo_udf = f.udf(lambda names: [foo(x) for x in names] if names is not None else None,
                ArrayType(StringType()))
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
#+------+-----------------------+
#|type  |names                  |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE]      |
#|pet   |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+

In this specific example, I could avoid the udf by exploding the column, calling pyspark.sql.functions.upper(), and then using groupBy and collect_list:

df.select('type', f.explode('names').alias('name'))\
    .withColumn('name', f.upper(f.col('name')))\
    .groupBy('type')\
    .agg(f.collect_list('name').alias('names'))\
    .show(truncate=False)
#+------+-----------------------+
#|type  |names                  |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE]      |
#|pet   |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+

But this is a lot of code to do something simple. Is there a more direct way to iterate over the elements of an ArrayType() using spark-dataframe functions?
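Note that the explode + groupBy + collect_list round trip is not an exact inverse: `explode` emits no rows for an empty or null array, so those rows vanish after the `groupBy`; rows that share the same `type` are merged into one; and the Spark docs note that the order produced by `collect_list` can be non-deterministic after a shuffle. The plain-Python model below (the function `explode_then_collect` is hypothetical; it imitates the DataFrame steps on a list of tuples and does not call Spark) shows the first two effects:

```python
from collections import defaultdict


def explode_then_collect(rows, fn):
    """Model of select(explode) -> withColumn(fn) -> groupBy -> collect_list."""
    # explode: one output row per array element; empty/None arrays produce nothing
    exploded = [(key, fn(name)) for key, names in rows for name in (names or [])]
    # groupBy + collect_list: rows with the same key are merged into one list
    grouped = defaultdict(list)
    for key, name in exploded:
        grouped[key].append(name)
    return dict(grouped)


rows = [
    ('person', ['john', 'sam']),
    ('pet', ['whiskers']),
    ('person', ['jane']),  # same key: merged with the first 'person' row
    ('robot', []),         # empty array: no exploded rows, so it disappears
]
print(explode_then_collect(rows, str.upper))
# {'person': ['JOHN', 'SAM', 'JANE'], 'pet': ['WHISKERS']}
```

The element order in this model is deterministic only because it is plain Python; in Spark it is not guaranteed.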

asked Feb 26 '18 16:02 by pault


2 Answers

In Spark < 2.4 you can use a user-defined function (UDF):

import pyspark.sql.functions as f
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DataType, StringType

def transform(func, t=StringType()):
    # Build a UDF that applies func to every element of an array column
    if not isinstance(t, DataType):
        raise TypeError("Invalid type {}".format(type(t)))
    @udf(ArrayType(t))
    def _(xs):
        # A null array arrives as None; falling through returns None (null)
        if xs is not None:
            return [func(x) for x in xs]
    return _

foo_udf = transform(str.upper)

df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+

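Given the high cost of the explode + collect_list idiom, which needs a shuffle for the groupBy, this approach is almost always preferred, despite the intrinsic cost of a Python UDF (every row has to be serialized between the JVM and the Python workers).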
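Because the body of such a UDF is ordinary Python, the element-wise logic, including the `None` guard for null arrays, can be unit-tested without a Spark session. A minimal sketch; `map_array` is a hypothetical helper that mirrors the inner function `_` above:

```python
def map_array(func, xs):
    """Apply func to every element of xs; pass a null array (None) through unchanged."""
    if xs is None:
        return None
    return [func(x) for x in xs]


print(map_array(str.upper, ['john', 'sam', 'jane']))  # ['JOHN', 'SAM', 'JANE']
print(map_array(str.upper, None))                     # None
```

It can then be wrapped as, for example, `udf(lambda xs: map_array(str.upper, xs), ArrayType(StringType()))`.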

In Spark 2.4 or later you can use the SQL higher-order function transform* together with upper (see SPARK-23909):

from pyspark.sql.functions import expr

df.withColumn(
    'names', expr('transform(names, x -> upper(x))')
).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+

It is also possible to use pandas_udf:

from pyspark.sql.functions import pandas_udf, PandasUDFType

def transform_pandas(func, t=StringType()):
    if not isinstance(t, DataType):
        raise TypeError("Invalid type {}".format(type(t)))
    # SCALAR pandas UDF: receives and returns one pandas.Series per batch
    @pandas_udf(ArrayType(t), PandasUDFType.SCALAR)
    def _(batch):
        return batch.apply(lambda xs: [func(x) for x in xs] if xs is not None else xs)
    return _

foo_udf_pandas = transform_pandas(str.upper)

df.withColumn('names', foo_udf_pandas(f.col('names'))).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+

Note that only recent Arrow / PySpark combinations support ArrayType columns in pandas UDFs (SPARK-24259, SPARK-21187). Nonetheless, this option should be more efficient than a standard UDF (mainly thanks to lower serde overhead) while still supporting arbitrary Python functions.
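Inside a scalar pandas UDF the function receives a whole batch as a `pandas.Series` whose elements are the arrays (or `None` for null rows), so the batch logic can be checked with plain pandas before wrapping it. In this sketch `upper_batch` is a hypothetical name, and the Python lists stand in for what Arrow delivers (in practice the elements may arrive as NumPy arrays, which the comprehension iterates over just the same):

```python
import pandas as pd


def upper_batch(s: pd.Series) -> pd.Series:
    """Upper-case every element of every array in the batch; keep null arrays as-is."""
    return s.apply(lambda arr: [x.upper() for x in arr] if arr is not None else arr)


batch = pd.Series([['john', 'sam', 'jane'], None, []])
print(upper_batch(batch).tolist())
```

With the type-hint style of Spark 3.0+, a function like this can be decorated with `@pandas_udf(ArrayType(StringType()))` instead of passing `PandasUDFType.SCALAR`.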


* A number of other higher-order functions are also supported, including, but not limited to, filter and aggregate. See for example:

  • Querying Spark SQL DataFrame with complex types
  • How to slice and sum elements of array column?
  • Filter array column content
  • Spark Scala row-wise average by handling null.
  • How to use transform higher-order function?.
answered Sep 27 '22 22:09 by 10465355


Yes, you can do it by converting the DataFrame to an RDD, transforming the values with mapValues, and converting the result back to a DataFrame.

>>> df.show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[john, sam, jane]      |
|pet   |[whiskers, rover, fido]|
+------+-----------------------+

>>> df.rdd.mapValues(lambda x: [y.upper() for y in x]).toDF(["type","names"]).show(truncate=False)
+------+-----------------------+
|type  |names                  |
+------+-----------------------+
|person|[JOHN, SAM, JANE]      |
|pet   |[WHISKERS, ROVER, FIDO]|
+------+-----------------------+
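`mapValues` treats each `Row` as a `(key, value)` pair, which works here only because the DataFrame has exactly two columns: `type` is kept as the key and the function is applied to `names`. The plain-Python model below does not use Spark; `map_values`, `Row2` and `Row3` are hypothetical names, and `map_values` mirrors the `(kv[0], func(kv[1]))` logic used in PySpark's `mapValues` source. It shows that a third column would be silently lost:

```python
from collections import namedtuple

# Hypothetical stand-ins for Rows with two and three fields
Row2 = namedtuple("Row2", ["type", "names"])
Row3 = namedtuple("Row3", ["type", "names", "owner"])


def map_values(rows, func):
    """Model of RDD.mapValues: keep element 0, apply func to element 1."""
    return [(kv[0], func(kv[1])) for kv in rows]


upper_all = lambda xs: [y.upper() for y in xs]
print(map_values([Row2('pet', ['rover', 'fido'])], upper_all))
# [('pet', ['ROVER', 'FIDO'])]
print(map_values([Row3('pet', ['rover', 'fido'], 'sam')], upper_all))
# [('pet', ['ROVER', 'FIDO'])]  <- the 'owner' field is gone
```

For wider DataFrames, a plain `df.rdd.map(...)` that rebuilds the whole row, or one of the DataFrame approaches above, avoids this.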
answered Sep 27 '22 20:09 by Bala