I have a pyspark.sql.dataframe.DataFrame, stack, and would like to filter its rows on multiple values rather than a single one, {val}. I am using a Python 2 Jupyter notebook. Presently, I do the following:
stack = hiveContext.sql("""
SELECT *
FROM db.table
WHERE col_1 != ''
""")
stack.show()
+---+-------+-------+---------+
| id| col_1 | . . . | list |
+---+-------+-------+---------+
| 1 | 524 | . . . |[1, 2] |
| 2 | 765 | . . . |[2, 3] |
.
.
.
| 9 | 765 | . . . |[4, 5, 8]|
vals = [2, 5]  # example values to look for
for val in vals:
    # keep rows whose array column `list` contains this single value
    filtered_stack = stack.filter("array_contains(list, {val})".format(val=val))
    # (some query on filtered_stack)
How would I rewrite this in Python to filter rows on more than one value at once, i.e. where {val} is an array of one or more elements?
My question is related to "ARRAY_CONTAINS multiple values in hive"; however, I'm trying to achieve the above in PySpark from a Python 2 Jupyter notebook.
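Filtering on "more than one value" can mean either of two predicates: keep a row if its list contains any of the values, or only if it contains all of them. The sketch below models both in plain Python rather than Spark, using rows copied from the stack.show() output; the helper names matches_any and matches_all are hypothetical.

```python
# Plain-Python model of the two possible row predicates (illustrative, not Spark code).
def matches_any(row_list, vals):
    """True if row_list contains at least one of vals (a null list never matches)."""
    return row_list is not None and any(v in row_list for v in vals)


def matches_all(row_list, vals):
    """True if row_list contains every one of vals (a null list never matches)."""
    return row_list is not None and all(v in row_list for v in vals)


# Sample id -> list values mirroring the stack.show() output
rows = {1: [1, 2], 2: [2, 3], 9: [4, 5, 8]}

any_ids = sorted(i for i, lst in rows.items() if matches_any(lst, [2, 5]))
all_ids = sorted(i for i, lst in rows.items() if matches_all(lst, [4, 8]))
print(any_ids)  # [1, 2, 9]
print(all_ids)  # [9]
```

The UDF and array_intersect answers test the "any" version, as a non-empty intersection.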
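With a Python UDF:
from pyspark.sql.functions import udf, size
from pyspark.sql.types import ArrayType, IntegerType

# Returns a UDF that computes the elements common to two array columns
# (null if either array is null); element_type is the Spark type of the elements.
def intersect(element_type):
    return udf(
        lambda x, y: (
            list(set(x) & set(y)) if x is not None and y is not None else None),
        ArrayType(element_type))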
df = sc.parallelize([([1, 2, 3], [1, 2]), ([3, 4], [5, 6])]).toDF(["xs", "ys"])
integer_intersect = intersect(IntegerType())
df.select(
integer_intersect("xs", "ys"),
size(integer_intersect("xs", "ys"))).show()
+----------------+----------------------+
|<lambda>(xs, ys)|size(<lambda>(xs, ys))|
+----------------+----------------------+
| [1, 2]| 2|
| []| 0|
+----------------+----------------------+
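Because the UDF body is ordinary Python, its logic can be checked without a Spark session. A minimal sketch, with the lambda pulled out into a hypothetical named function intersect_values; since it goes through set, duplicates are dropped and the order of the result is not guaranteed:

```python
def intersect_values(x, y):
    """Same logic as the UDF's lambda: common elements, or None if either side is null."""
    if x is None or y is None:
        return None
    return list(set(x) & set(y))


print(sorted(intersect_values([1, 2, 3], [1, 2])))  # [1, 2]
print(intersect_values([3, 4], [5, 6]))             # []
print(intersect_values(None, [1, 2]))               # None
```

The None checks matter: Spark passes a null array to a Python UDF as None, and set(None) would raise a TypeError.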
With a literal array:
from pyspark.sql.functions import array, lit
df.select(integer_intersect("xs", array(lit(1), lit(5)))).show()
+-------------------------+
|<lambda>(xs, array(1, 5))|
+-------------------------+
| [1]|
| []|
+-------------------------+
Or, to keep only the rows where xs shares at least one element with the literal:
df.where(size(integer_intersect("xs", array(lit(1), lit(5)))) > 0).show()
+---------+------+
| xs| ys|
+---------+------+
|[1, 2, 3]|[1, 2]|
+---------+------+
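Without UDFs (array_intersect requires Spark 2.4 or later):
import pyspark.sql.functions as F

vals = {1, 2, 3}

# Distinct elements of the `list` column that also appear in vals
common = F.array_intersect(
    F.col("list"),
    F.array([F.lit(i) for i in vals])
)

# Boolean column: true for any row whose list shares at least one value with vals
has_any = F.size(common) > 0

filtered_stack = stack.where(has_any)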
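Since array_intersect returns the distinct elements present in both arrays, the same expression also covers the "contains all" case: compare its size with the number of distinct values in vals (in PySpark, F.size(F.array_intersect(...)) == len(vals) when vals is a set) instead of testing > 0. For the "any" case, Spark 2.4+ also provides F.arrays_overlap(a1, a2). The following plain-Python model, not Spark code, assumes those documented semantics and shows why the de-duplication matters; common_values, keep_any and keep_all are hypothetical names.

```python
# Plain-Python model of the two PySpark predicates (illustrative, not Spark code):
#   any: F.size(F.array_intersect(F.col("list"), arr)) > 0
#   all: F.size(F.array_intersect(F.col("list"), arr)) == number of distinct vals
def common_values(row_list, vals):
    """Distinct elements of row_list that also occur in vals."""
    result = []
    for x in row_list:
        if x in vals and x not in result:
            result.append(x)
    return result


def keep_any(row_list, vals):
    return len(common_values(row_list, vals)) > 0


def keep_all(row_list, vals):
    return len(common_values(row_list, vals)) == len(set(vals))


vals = {1, 2, 3}
print(keep_any([1, 2], vals))        # True
print(keep_all([1, 2], vals))        # False
print(keep_all([3, 2, 1, 1], vals))  # True
# Without de-duplication, [2, 2] would give size 2 == len({2, 5}) and wrongly pass
print(keep_all([2, 2], {2, 5}))      # False
```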