 

How to filter based on array value in PySpark?

My Schema:

 |-- Canonical_URL: string (nullable = true)
 |-- Certifications: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Certification_Authority: string (nullable = true)
 |    |    |-- End: string (nullable = true)
 |    |    |-- License: string (nullable = true)
 |    |    |-- Start: string (nullable = true)
 |    |    |-- Title: string (nullable = true)
 |-- CompanyId: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- vendorTags: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- score: double (nullable = true)
 |    |    |-- vendor: string (nullable = true)

I tried the query below to select nested fields from vendorTags:

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts")

How can I query the nested fields in a where clause in PySpark, like below?

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'")

or

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.score > 123.123456")

something like this.

I tried the above queries, only to get the error below:

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'")
16/03/15 13:16:02 INFO ParseDriver: Parsing command: select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'
16/03/15 13:16:03 INFO ParseDriver: Parse Completed
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/spark/python/pyspark/sql/context.py", line 583, in sql
    return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
  File "/usr/lib/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py", line 813, in __call__
  File "/usr/lib/spark/python/pyspark/sql/utils.py", line 51, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
pyspark.sql.utils.AnalysisException: u"cannot resolve '(vendorTags.vendor = cast(alpha as double))' due to data type mismatch: differing types in '(vendorTags.vendor = cast(alpha as double))' (array<string> and double).; line 1 pos 71"
asked Mar 15 '16 by Suhas Chandramouli

People also ask

Which method is used to filter DataFrame values in PySpark?

where() is a method used to filter the rows of a DataFrame based on a given condition. The where() method is an alias for the filter() method; both operate exactly the same.
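
For instance, a minimal sketch (assuming a DataFrame df with the Country column from the schema above; 'US' is just an illustrative value):

# Keep only rows whose Country column equals 'US'
df.where(df.Country == 'US').show()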

How do you filter out rows in PySpark?

The PySpark filter() function is used to filter rows from an RDD/DataFrame based on a given condition or SQL expression. You can also use the where() clause instead of filter() if you are coming from an SQL background; both functions operate exactly the same.
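
A small sketch (CompanyId comes from the schema above; the literal value is made up for illustration):

# filter() accepts either a Column expression or an SQL expression string
df.filter("CompanyId = '12345'").show()
df.filter(df.CompanyId == '12345').show()  # equivalent Column form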

How do I use the isin() function in PySpark?

PySpark's isin() (the IN operator) is used to check or filter whether DataFrame values are contained in a list of values. isin() is a function of the Column class; it returns True if the value of the expression is contained in the evaluated values of the arguments.
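
A minimal sketch (again using the Country column from the schema above; the country codes are illustrative):

# Keep rows whose Country is any of the listed values
df.where(df.Country.isin("US", "IN", "DE")).show()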

What is withColumn in PySpark?

In PySpark, the withColumn() function is a widely used DataFrame transformation for changing a column's values, converting the datatype of an existing column, creating a new column, and so on.
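
For example, a minimal sketch (upper-casing the Country column from the schema above; the transformation itself is just illustrative):

from pyspark.sql.functions import upper

# Replace the Country column with an upper-cased version of itself
df = df.withColumn("Country", upper(df.Country))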


2 Answers

For equality-based queries you can use array_contains:

df = sc.parallelize([(1, [1, 2, 3]), (2, [4, 5, 6])]).toDF(["k", "v"])
df.createOrReplaceTempView("df")

# With SQL
sqlContext.sql("SELECT * FROM df WHERE array_contains(v, 1)")

# With DSL
from pyspark.sql.functions import array_contains
df.where(array_contains("v", 1))
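
Applied to the schema from the question, that would look roughly like this (a sketch; vendorTags.vendor resolves to an array<string>, as the error message above also shows, and globalcontacts is the temp table from the question):

# Rows where any vendorTags element has vendor == 'alpha'
sqlContext.sql(
    "SELECT vendorTags.vendor FROM globalcontacts "
    "WHERE array_contains(vendorTags.vendor, 'alpha')")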

If you want to use more complex predicates you'll have to either explode or use a UDF, for example something like this:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf 

def exists(f):
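    # Wrap a plain Python predicate f in a UDF that returns True
    # when any element of the array column satisfies it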
    return udf(lambda xs: any(f(x) for x in xs), BooleanType())

df.where(exists(lambda x: x > 3)("v"))
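
For the question's schema, the same helper could be applied to the score predicate roughly like this (a sketch; df here stands for the question's globalcontacts DataFrame, and vendorTags.score resolves to an array of doubles):

# Rows where at least one vendorTags entry has score > 123.123456
df.where(exists(lambda x: x > 123.123456)("vendorTags.score"))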

In Spark 2.4 or later it is also possible to use higher-order functions:

from pyspark.sql.functions import expr

df.where(expr("""aggregate(
    transform(v, x -> x > 3),
    false, 
    (x, y) -> x or y
)"""))

or

df.where(expr("""
    exists(v, x -> x > 3)
"""))

Python wrappers should be available in 3.1 (SPARK-30681).
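
With those wrappers, the last query could be written roughly as follows (a sketch, assuming Spark 3.1+ where pyspark.sql.functions.exists is available; it is imported under an alias here only to avoid clashing with the exists helper defined above):

from pyspark.sql.functions import exists as array_exists

# exists(col, f) returns True when any element of the array satisfies the predicate
df.where(array_exists("v", lambda x: x > 3))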

answered by zero323

In Spark 2.4 you can filter array values using the filter function in the SQL API.

https://spark.apache.org/docs/2.4.0/api/sql/index.html#filter

Here's an example in PySpark. In the example we filter out all array values that are empty strings:

from pyspark.sql.functions import expr
df = df.withColumn("ArrayColumn", expr("filter(ArrayColumn, x -> x != '')"))
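
Applied to the question's schema, the same function can trim an array of structs, for example keeping only the vendorTags entries whose score exceeds the threshold from the question (a sketch; df stands for the question's globalcontacts DataFrame):

# Keep only vendorTags elements whose score field is above the threshold
df = df.withColumn(
    "vendorTags",
    expr("filter(vendorTags, x -> x.score > 123.123456)"))
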
answered by Jack