
Spark - pass full row to a udf and then get column name inside udf

I am using Spark with Scala and want to pass the entire row to a UDF, then read each column name and column value inside the UDF. How can I do this?

I am trying the following:

inputDataDF.withColumn("errorField", mapCategory(ruleForNullValidation) (col(_*)))

def mapCategory(categories: Map[String, Boolean]) = {
  udf((input: Row) => {
    // TODO: write a recursive function that checks whether each column is in categories;
    // if so, check it for null and return false if it is null;
    // repeat this for all columns and then combine the results
  })
}
user1122 asked Mar 06 '23 03:03


2 Answers

In Spark 1.6 you can use Row as the external type and struct as the expression. Column names can be fetched from the schema. For example:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, struct, udf}

val df = Seq((1, 2, 3)).toDF("a", "b", "c")
val f = udf((row: Row) => row.schema.fieldNames)
df.select(f(struct(df.columns map col: _*))).show

// +-----------------------------------------------------------------------------+
// |UDF(named_struct(NamePlaceholder, a, NamePlaceholder, b, NamePlaceholder, c))|
// +-----------------------------------------------------------------------------+
// |                                                                    [a, b, c]|
// +-----------------------------------------------------------------------------+

Values can be accessed by name using the Row.getAs method.
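
To make the name-based access concrete, here is a minimal sketch that combines row.schema.fieldNames with Row.getAs. The object name RowFieldsSketch, the helper describeRow, the output column names, and the local SparkSession settings are assumptions made for illustration; in spark-shell the session already exists as spark.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, struct, udf}

object RowFieldsSketch {
  // Pair every column name from the row's schema with its value.
  // Assumes the row carries a schema, as rows built from struct(...) do.
  def describeRow(row: Row): String =
    row.schema.fieldNames
      .map(name => s"$name=${row.getAs[Any](name)}")
      .mkString(", ")

  def main(args: Array[String]): Unit = {
    // Local session for illustration only
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("row-fields-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, 3)).toDF("a", "b", "c")

    // UDF that lists every name=value pair of the row
    val describe = udf((row: Row) => describeRow(row))
    // getAs takes the expected value type as a type parameter
    val sumAC = udf((row: Row) => row.getAs[Int]("a") + row.getAs[Int]("c"))

    // Wrap all columns into a single struct so the UDF receives one Row
    val allCols = struct(df.columns.map(col): _*)
    df.select(sumAC(allCols).as("a_plus_c"), describe(allCols).as("fields")).show(false)

    spark.stop()
  }
}
```

Note that getAs[Int] on a null value would fail at runtime, so nullable columns are better checked with row.isNullAt first.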

Alper t. Turker answered Mar 10 '23 10:03


Here is a simple working example:

Input Data:

+-----+---+--------+
| NAME|AGE|CATEGORY|
+-----+---+--------+
|  RIO| 35|     FIN|
|  TOM| 90|     ACC|
|KEVIN| 32|        |
| STEF| 22|     OPS|
+-----+---+--------+

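// Define the category list and the UDF

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{struct, udf}

val categoryList = List("FIN", "ACC")
// Returns true when the row's CATEGORY value is one of the allowed categories
def mapCategoryUDF(ls: List[String]) = udf[Boolean, Row]((x: Row) => ls.contains(x.getAs[String]("CATEGORY")))

// struct("*") packs every column of the row into a single struct argument
df.withColumn("errorField", mapCategoryUDF(categoryList)(struct("*"))).show()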

Result should look like this:

+-----+---+--------+----------+
| NAME|AGE|CATEGORY|errorField|
+-----+---+--------+----------+
|  RIO| 35|     FIN|      true|
|  TOM| 90|     ACC|      true|
|KEVIN| 32|        |     false|
| STEF| 22|     OPS|     false|
+-----+---+--------+----------+
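
The same pattern can be stretched toward the null validation described in the question. The following is only a sketch under stated assumptions: the rules map is read as "column name -> must not be null", the names NullRuleSketch and nullViolations are hypothetical, and the sample data uses real nulls rather than the blank CATEGORY shown above.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{struct, udf}

object NullRuleSketch {
  // Returns the names of columns that are flagged in `rules` but hold null.
  // Assumption: rules(name) == true means "this column must not be null".
  def nullViolations(rules: Map[String, Boolean])(row: Row): Seq[String] =
    row.schema.fieldNames.toSeq.filter { name =>
      rules.getOrElse(name, false) && row.isNullAt(row.fieldIndex(name))
    }

  def main(args: Array[String]): Unit = {
    // Local session for illustration only
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("null-rule-sketch")
      .getOrCreate()
    import spark.implicits._

    // None becomes a null CATEGORY for KEVIN
    val df = Seq[(String, Option[String])](
      ("RIO", Some("FIN")),
      ("KEVIN", None)
    ).toDF("NAME", "CATEGORY")

    val rules = Map("NAME" -> true, "CATEGORY" -> true)
    // Comma-separated list of violating columns; empty when the row is valid
    val check = udf((row: Row) => nullViolations(rules)(row).mkString(","))

    df.withColumn("errorField", check(struct("*"))).show()
    spark.stop()
  }
}
```

Iterating over the schema's field names avoids hard-coding column names in the UDF, so the rule map alone decides which columns are checked.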

Hope this helps!!

1pluszara answered Mar 10 '23 11:03