Filtering dataframe array items based on an external array with intersection

I'm trying to filter the elements of WrappedArray columns in a DataFrame, keeping only the items that appear in an external list.

While looking for a solution, I found this question. It is very similar, but it does not seem to work for me. I'm using Spark 2.4.0. This is my code:

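import org.apache.spark.sql.functions.{col, udf}
import spark.implicits._ // provides toDF; already imported in spark-shell

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")

// UDF that keeps only the elements of recs that also appear in flist (null rows become empty)
def filterItems(flist: Seq[String]) = udf {
  (recs: Seq[String]) => recs match {
    case null => Seq.empty[String]
    case recs => recs.intersect(flist)
  }
}

df.withColumn("filtercol", filterItems(Seq("s", "v"))(col("bar"))).show(5)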
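Because the lambda inside filterItems is ordinary Scala, its per-row logic can be checked without Spark. The sketch below lifts that lambda into a plain function (the object name FilterItemsLogic is hypothetical) and expresses the null match with Option. It reproduces the expected filtercol values, which suggests the intersection logic itself is fine; the exception message refers to deserializing a field of an RDD, not to code inside the lambda.

```scala
// Plain-Scala version of the lambda wrapped by filterItems; runs without Spark.
// FilterItemsLogic is a hypothetical name used only for this illustration.
object FilterItemsLogic {
  // Null-safe intersection: a null row becomes an empty Seq, as in the original UDF.
  def filterItems(flist: Seq[String])(recs: Seq[String]): Seq[String] =
    Option(recs).map(_.intersect(flist)).getOrElse(Seq.empty[String])

  def main(args: Array[String]): Unit = {
    val rows: Seq[Seq[String]] =
      Seq(Seq("s", "v", "r"), Seq("r", "a", "v"), Seq("s", "r", "t"), null)
    // Apply the same filter list the question uses to every sample row
    rows.foreach(r => println(filterItems(Seq("s", "v"))(r)))
  }
}
```

Seq.intersect keeps the order of the row it is called on, so [s, v, r] yields [s, v], matching the expected table.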

My expected result would be:

+---+---------+---------+
|foo|      bar|filtercol|
+---+---------+---------+
|  1|[s, v, r]|   [s, v]|
|  2|[r, a, v]|      [v]|
|  3|[s, r, t]|      [s]|
+---+---------+---------+

But I'm getting this error:

java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
asked Jan 26 '26 08:01 by pez_betta


1 Answer

You can do this with a built-in function added in Spark 2.4, without much effort:

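import org.apache.spark.sql.functions.{array_intersect, array, lit}
import spark.implicits._ // provides toDF and $"col"; already imported in spark-shell

val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")

// Wrap each filter value in a literal Column, then combine them into one array column
val ar = Seq("s", "v").map(lit(_))
df.withColumn("filtercol", array_intersect($"bar", array(ar:_*))).show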

Output:

+---+---------+---------+
|foo|      bar|filtercol|
+---+---------+---------+
|  1|[s, v, r]|   [s, v]|
|  2|[r, a, v]|      [v]|
|  3|[s, r, t]|      [s]|
+---+---------+---------+

The only tricky part is Seq("s", "v").map(lit(_)), which wraps each string in a literal column, lit(s). array_intersect takes two array columns: the first is the bar column, and the second is built on the fly with array(ar:_*) from those literal columns. The result contains the elements common to both arrays, without duplicates.
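A possible variation, sketched under the assumption that typedLit is available (it has been part of org.apache.spark.sql.functions since Spark 2.2): typedLit turns the whole Scala Seq into a single array literal, so the per-element lit mapping is not needed. The sketch also wraps the result in coalesce, because array_intersect returns null when its input array is null, whereas the UDF in the question returned an empty Seq. The names TypedLitIntersect and filterByList are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array_intersect, coalesce, col, typedLit}

// TypedLitIntersect / filterByList are hypothetical names for this sketch.
object TypedLitIntersect {
  // Adds "filtercol": the elements of arrayCol that also appear in `wanted`.
  // typedLit builds one array literal from the whole Seq; coalesce maps a
  // null input array to an empty array, mirroring the question's UDF.
  def filterByList(df: DataFrame, arrayCol: String, wanted: Seq[String]): DataFrame =
    df.withColumn(
      "filtercol",
      coalesce(array_intersect(col(arrayCol), typedLit(wanted)), typedLit(Seq.empty[String]))
    )

  def main(args: Array[String]): Unit = {
    // Local session for illustration only
    val spark = SparkSession.builder().master("local[1]").appName("typedlit-intersect").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, Seq("s", "v", "r")),
      (2, Seq("r", "a", "v")),
      (3, Seq("s", "r", "t")),
      (4, null: Seq[String]) // null array, to exercise the coalesce fallback
    ).toDF("foo", "bar")

    filterByList(df, "bar", Seq("s", "v")).show()
    spark.stop()
  }
}
```

On this sample data the four rows should come out as [s, v], [v], [s] and [].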

answered Jan 28 '26 06:01 by abiratsis


