 

How to filter one spark dataframe against another dataframe

I'm trying to filter one dataframe against another:

scala> val df1 = sc.parallelize((1 to 100).map(a=>(s"user $a", a*0.123, a))).toDF("name", "score", "user_id")
scala> val df2 = sc.parallelize(List(2,3,4,5,6)).toDF("valid_id")

Now I want to filter df1 and get back a dataframe that contains all the rows in df1 where user_id is in df2("valid_id"). In other words, I want all the rows in df1 where the user_id is either 2, 3, 4, 5, or 6.

scala> df1.select("user_id").filter($"user_id" in df2("valid_id"))
warning: there were 1 deprecation warning(s); re-run with -deprecation for details
org.apache.spark.sql.AnalysisException: resolved attribute(s) valid_id#20 missing from user_id#18 in operator !Filter user_id#18 IN (valid_id#20);  

On the other hand, when I filter against a function, everything looks great:

scala> df1.select("user_id").filter(($"user_id" % 2) === 0)
res1: org.apache.spark.sql.DataFrame = [user_id: int]

Why am I getting this error? Is there something wrong with my syntax?
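
As a side note (not from the original post, just a sketch assuming df2 is small enough to collect): the deprecation warning above appears because the in operator was deprecated in favor of isin in Spark 1.5, and isin expects local values rather than a column from another DataFrame, so one workaround is to pull the valid ids down to the driver first:

scala> // Sketch: collect the small list of valid ids to the driver,
scala> // then filter with Column.isin, which takes local values as varargs.
scala> val validIds = df2.collect().map(_.getInt(0))
scala> df1.filter($"user_id".isin(validIds: _*)).show()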

Following a comment, I have also tried a left outer join:

scala> df1.show
+-------+------------------+-------+
|   name|             score|user_id|
+-------+------------------+-------+
| user 1|             0.123|      1|
| user 2|             0.246|      2|
| user 3|             0.369|      3|
| user 4|             0.492|      4|
| user 5|             0.615|      5|
| user 6|             0.738|      6|
| user 7|             0.861|      7|
| user 8|             0.984|      8|
| user 9|             1.107|      9|
|user 10|              1.23|     10|
|user 11|             1.353|     11|
|user 12|             1.476|     12|
|user 13|             1.599|     13|
|user 14|             1.722|     14|
|user 15|             1.845|     15|
|user 16|             1.968|     16|
|user 17|             2.091|     17|
|user 18|             2.214|     18|
|user 19|2.3369999999999997|     19|
|user 20|              2.46|     20|
+-------+------------------+-------+
only showing top 20 rows

scala> df2.show
+--------+
|valid_id|
+--------+
|       2|
|       3|
|       4|
|       5|
|       6|
+--------+

scala> df1.join(df2, df1("user_id") === df2("valid_id"))
res6: org.apache.spark.sql.DataFrame = [name: string, score: double, user_id: int, valid_id: int]
scala> res6.collect
res7: Array[org.apache.spark.sql.Row] = Array()

scala> df1.join(df2, df1("user_id") === df2("valid_id"), "left_outer")
res8: org.apache.spark.sql.DataFrame = [name: string, score: double, user_id: int, valid_id: int]
scala> res8.count
res9: Long = 0

I'm running Spark 1.5.0 with Scala 2.10.5.

Asked Sep 18 '15 by polo


People also ask

How do I add a filter to a DataFrame in Spark?

Spark's filter() and where() functions filter the rows of a DataFrame or Dataset based on one or more conditions or a SQL expression. If you come from a SQL background, you can use where() instead of filter(); the two behave exactly the same.
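
For example, with the df1 defined in the question, these two calls are interchangeable (a small illustrative sketch, not taken from the original post):

scala> df1.filter($"score" > 1.0).count()
scala> df1.where($"score" > 1.0).count()   // same result as filter()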

What is anti join in Spark?

Anti Join. An anti join returns the rows from the left relation that have no match in the right relation. It is also referred to as a left anti join.
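
As an illustrative sketch (note that the "left_anti" join type was only introduced in Spark 2.0, so it is not available in the Spark 1.5.0 setup from the question), an anti join would return the rows of df1 whose user_id does not appear in df2:

scala> // Rows of df1 with no matching valid_id in df2 (requires Spark 2.0+).
scala> df1.join(df2, df1("user_id") === df2("valid_id"), "left_anti").show()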


1 Answer

You want a (regular) inner join, not an outer join :)

df1.join(df2, df1("user_id") === df2("valid_id"))
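
If you only want the original df1 columns back, you can drop the join key afterwards (a small sketch using the DataFrames defined in the question):

scala> // Inner join keeps only rows whose user_id appears in df2,
scala> // then drop the extra valid_id column from the result.
scala> val filtered = df1.join(df2, df1("user_id") === df2("valid_id")).drop("valid_id")
scala> filtered.show()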
Answered by Glennie Helles Sindholt