SPARK_VERSION = 2.2.0
I ran into an interesting issue when trying to filter a dataframe on a column that was added using a UDF. I am able to reproduce the problem with a smaller set of data.
Given the dummy case classes:
case class Info(number: Int, color: String)
case class Record(name: String, infos: Seq[Info])
and the following data:
val blue = Info(1, "blue")
val black = Info(2, "black")
val yellow = Info(3, "yellow")
val orange = Info(4, "orange")
val white = Info(5, "white")
val a = Record("a", Seq(blue, black, white))
val a2 = Record("a", Seq(yellow, white, orange))
val b = Record("b", Seq(blue, black))
val c = Record("c", Seq(white, orange))
val d = Record("d", Seq(orange, black))
do the following.
Create two dataframes (we will call them left and right):
val left = Seq(a, b).toDF
val right = Seq(a2, c, d).toDF
Join those dataframes using a full_outer join, and take only what is on the right side:
val rightOnlyInfos = left.alias("l")
  .join(right.alias("r"), Seq("name"), "full_outer")
  .filter("l.infos is null")
  .select($"name", $"r.infos".as("r_infos"))
This results in the following:
rightOnlyInfos.show(false)
+----+-----------------------+
|name|r_infos |
+----+-----------------------+
|c |[[5,white], [4,orange]]|
|d |[[4,orange], [2,black]]|
+----+-----------------------+
Using the following UDF, add a new boolean column that represents whether or not any of the r_infos entries contains the color black:
def hasBlack = (s: Seq[Row]) => {
  s.exists { case Row(num: Int, color: String) =>
    color == "black"
  }
}
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", udf(hasBlack).apply($"r_infos"))
This is where I am seeing problems now. If I do the following, I get no errors:
joinedBreakdown.show(false)
and it results (as expected) in:
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|c |[[5,white], [4,orange]]|false |
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
and the schema
joinedBreakdown.printSchema
shows
root
|-- name: string (nullable = true)
|-- r_infos: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- number: integer (nullable = false)
| | |-- color: string (nullable = true)
|-- has_black: boolean (nullable = true)
However, when I try to filter on that result:
joinedBreakdown.filter("has_black == true").show(false)
I get the following error:
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$hasBlack$1: (array<struct<number:int,color:string>>) => boolean)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1075)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:411)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$canFilterOutNull(joins.scala:127)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$rightHasNonNullPredicate$lzycompute$1$1.apply(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$rightHasNonNullPredicate$lzycompute$1$1.apply(joins.scala:138)
at scala.collection.LinearSeqOptimized$class.exists(LinearSeqOptimized.scala:93)
at scala.collection.immutable.List.exists(List.scala:84)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.rightHasNonNullPredicate$lzycompute$1(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.rightHasNonNullPredicate$1(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$buildNewJoinType(joins.scala:145)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:152)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:150)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.apply(joins.scala:150)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.apply(joins.scala:116)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:85)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:82)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:82)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
at scala.collection.immutable.List.foreach(List.scala:381)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:78)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:78)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan$lzycompute(QueryExecution.scala:84)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan(QueryExecution.scala:80)
at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:89)
at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:89)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:2832)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2153)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2366)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:245)
at org.apache.spark.sql.Dataset.show(Dataset.scala:646)
at org.apache.spark.sql.Dataset.show(Dataset.scala:623)
... 58 elided
Caused by: java.lang.NullPointerException
at $anonfun$hasBlack$1.apply(<console>:41)
at $anonfun$hasBlack$1.apply(<console>:40)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:92)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:91)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1072)
... 114 more
EDIT: I opened a JIRA issue. Linking it here for tracking purposes: https://issues.apache.org/jira/browse/SPARK-22942
This answer doesn't address why the issue exists, only workarounds that I have found for it.
I have run into a problem exactly like this. I'm not sure of the cause, but I have two workarounds that work for me. Someone much smarter than me will probably be able to explain it all to you, but here are my solutions to the problem.
FIRST SOLUTION
Spark is acting like the column doesn't exist yet, probably because of some kind of filter push-down: the stack trace goes through the EliminateOuterJoin optimizer rule, which evaluates the filter predicate while deciding whether the outer join can be simplified. Force Spark to cache the result prior to filtering. This makes the column "exist".
// hasBlack here is assumed to be the udf-wrapped version (see the note under SECOND SOLUTION)
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", hasBlack($"r_infos")).cache()
println(joinedBreakdown.count()) // This forces the results to be cached after the UDF has been applied.
joinedBreakdown.filter("has_black == true").show(false)
joinedBreakdown.filter("has_black == true").explain
OUTPUT
2
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
== Physical Plan ==
*Filter (has_black#112632 = true)
+- InMemoryTableScan [name#112622, r_infos#112628, has_black#112632], [(has_black#112632 = true)]
+- InMemoryRelation [name#112622, r_infos#112628, has_black#112632], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
+- *Project [coalesce(name#112606, name#112614) AS name#112622, infos#112615 AS r_infos#112628, UDF(infos#112615) AS has_black#112632]
+- *Filter isnull(infos#112607)
+- SortMergeJoin [name#112606], [name#112614], FullOuter
:- *Sort [name#112606 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(name#112606, 200)
: +- LocalTableScan [name#112606, infos#112607]
+- *Sort [name#112614 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(name#112614, 200)
+- LocalTableScan [name#112614, infos#112615]
SECOND SOLUTION
No idea why this one works, but do the same thing that you did, except put a try/catch in the UDF. And before I get yelled at for it, please know that using try/catch for control flow is an anti-pattern. To learn more, I recommend this question and answer. NOTE: I edited your UDF slightly to make it look like something I am more familiar with.
def hasBlack = udf((s: Seq[Row]) => {
  try {
    s.exists { case Row(num: Int, color: String) =>
      color == "black"
    }
  } catch {
    case ex: Exception => false
  }
})
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", hasBlack($"r_infos"))
joinedBreakdown.filter("has_black == true").explain
joinedBreakdown.filter("has_black == true").show(false)
OUTPUT
== Physical Plan ==
*Project [coalesce(name#112565, name#112573) AS name#112581, infos#112574 AS r_infos#112587, UDF(infos#112574) AS has_black#112591]
+- *Filter isnull(infos#112566)
+- *BroadcastHashJoin [name#112565], [name#112573], RightOuter, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(ArrayBuffer(input[0, string, false]))
: +- *Filter isnotnull(name#112565)
: +- LocalTableScan [name#112565, infos#112566]
+- *Filter (UDF(infos#112574) = true)
+- LocalTableScan [name#112573, infos#112574]
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
You can see that the query plans are different because I am forcing the UDF to be applied prior to the filter.
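If you would rather not use try/catch, an explicit null check in the UDF should have the same effect, since the trace bottoms out in a NullPointerException thrown from inside hasBlack. This is just a sketch along the same lines and not from the original post; hasBlackNullSafe is a name I made up:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
// (assuming the same spark-shell session as above, with spark.implicits._ in scope)

// Hypothetical variant of the second workaround: return false when the input
// sequence is null instead of catching the resulting NullPointerException.
def hasBlackNullSafe = udf((s: Seq[Row]) =>
  s != null && s.exists { case Row(num: Int, color: String) =>
    color == "black"
  }
)

rightOnlyInfos
  .withColumn("has_black", hasBlackNullSafe($"r_infos"))
  .filter("has_black == true")
  .show(false)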