I am trying to implement a custom UDT and be able to reference it from Spark SQL (as explained in the Spark SQL whitepaper, section 4.4.2).
The real example is to have a custom UDT backed by an off-heap data structure using Cap'n Proto, or similar.
For this posting, I have made up a contrived example. I know that I could just use Scala case classes and not have to do any work at all, but that isn't my goal.
For example, I have a Person containing several attributes and I want to be able to SELECT person.first_name FROM person. I'm running into the error Can't extract value from person#1 and I'm not sure why.
Here is the full source (also available at https://github.com/andygrove/spark-sql-udt)
package com.theotherandygrove
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
object Example {

  def main(arg: Array[String]): Unit = {

    val conf = new SparkConf()
      .setAppName("Example")
      .setMaster("local[*]")

    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    val schema = StructType(List(
      StructField("person_id", DataTypes.IntegerType, true),
      StructField("person", new MockPersonUDT, true)))

    // load initial RDD
    val rdd = sc.parallelize(List(
      MockPersonImpl(1),
      MockPersonImpl(2)
    ))

    // convert to RDD[Row]
    val rowRdd = rdd.map(person => Row(person.getAge, person))

    // convert to DataFrame (RDD + Schema)
    val dataFrame = sqlContext.createDataFrame(rowRdd, schema)

    // register as a table
    dataFrame.registerTempTable("person")
    // selecting the whole object (SELECT person FROM person) works fine,
    // but this query fails with "Can't extract value from person#1"
    val results = sqlContext.sql("SELECT person.first_name FROM person WHERE person.age < 100")

    val people = results.collect

    people.map(row => {
      println(row)
    })
  }
}
trait MockPerson {
  def getFirstName: String
  def getLastName: String
  def getAge: Integer
  def getState: String
}
class MockPersonUDT extends UserDefinedType[MockPerson] {

  override def sqlType: DataType = StructType(List(
    StructField("firstName", StringType, nullable = false),
    StructField("lastName", StringType, nullable = false),
    StructField("age", IntegerType, nullable = false),
    StructField("state", StringType, nullable = false)
  ))

  override def userClass: Class[MockPerson] = classOf[MockPerson]

  // only the age is serialized; deserialize reconstructs the person from it
  override def serialize(obj: Any): Any = obj.asInstanceOf[MockPersonImpl].getAge

  override def deserialize(datum: Any): MockPerson = MockPersonImpl(datum.asInstanceOf[Integer])
}
@SQLUserDefinedType(udt = classOf[MockPersonUDT])
@SerialVersionUID(123L)
case class MockPersonImpl(n: Integer) extends MockPerson with Serializable {
  def getFirstName = "First" + n
  def getLastName = "Last" + n
  def getAge = n
  def getState = "AK"
}
If I simply SELECT person FROM person then the query works. I just can't reference the attributes in SQL, even though they are defined in the schema.
You get this error because the schema defined by sqlType is never exposed and is not intended to be accessed directly. It simply provides a way to express complex data types using native Spark SQL types.
You can access individual attributes using UDFs, but first let's show that the internal structure is indeed not exposed:
dataFrame.printSchema
// root
// |-- person_id: integer (nullable = true)
// |-- person: mockperso (nullable = true)
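Trying to extract one of the sqlType fields through the DataFrame API fails in the same way as the SQL query in the question, because the analyzer cannot extract values from a UDT column. A quick check against the dataFrame built in the question (the exact exception text may vary slightly between Spark versions):
// fails during analysis with the same "Can't extract value from person#..." error
dataFrame.select("person.first_name")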
To create UDFs we need functions that take as an argument an object of the type represented by a given UDT:
import org.apache.spark.sql.functions.udf
import sqlContext.implicits._ // for the $"..." column syntax used below
val getFirstName = (person: MockPerson) => person.getFirstName
val getLastName = (person: MockPerson) => person.getLastName
val getAge = (person: MockPerson) => person.getAge
which can be wrapped using the udf function:
val getFirstNameUDF = udf(getFirstName)
val getLastNameUDF = udf(getLastName)
val getAgeUDF = udf(getAge)
dataFrame.select(
getFirstNameUDF($"person").alias("first_name"),
getLastNameUDF($"person").alias("last_name"),
getAgeUDF($"person").alias("age")
).show()
// +----------+---------+---+
// |first_name|last_name|age|
// +----------+---------+---+
// | First1| Last1| 1|
// | First2| Last2| 2|
// +----------+---------+---+
To use these with raw SQL you have to register the functions through SQLContext:
sqlContext.udf.register("first_name", getFirstName)
sqlContext.udf.register("last_name", getLastName)
sqlContext.udf.register("age", getAge)
sqlContext.sql("""
SELECT first_name(person) AS first_name, last_name(person) AS last_name
FROM person
WHERE age(person) < 100""").show
// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// | First1| Last1|
// | First2| Last2|
// +----------+---------+
Unfortunately it comes with a price tag attached. First of all, every operation requires deserialization. It also substantially limits the ways in which the query can be optimized; in particular, any join operation on one of these fields requires a Cartesian product.
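You can get a feel for this by looking at the physical plan of a self-join that has to go through the UDT column. The snippet below is only an illustrative check built on the age UDF registered above; the exact plan depends on the Spark version:
// Illustrative check: join the table with itself on a key that can only be
// computed by running the UDF (and therefore deserializing the person objects),
// then inspect the physical plan Spark chooses.
sqlContext.sql("""
  SELECT l.person_id, r.person_id
  FROM person l JOIN person r
  ON age(l.person) = age(r.person)""").explain()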
In practice, if you want to encode a complex structure whose attributes can be expressed using built-in types, it is better to use StructType:
case class Person(first_name: String, last_name: String, age: Int)
val df = sc.parallelize(
(1 to 2).map(i => (i, Person(s"First$i", s"Last$i", i)))).toDF("id", "person")
df.printSchema
// root
// |-- id: integer (nullable = false)
// |-- person: struct (nullable = true)
// | |-- first_name: string (nullable = true)
// | |-- last_name: string (nullable = true)
// | |-- age: integer (nullable = false)
df
  .where($"person.age" < 100)
  .select($"person.first_name", $"person.last_name")
  .show
// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// | First1| Last1|
// | First2| Last2|
// +----------+---------+
and reserve UDTs for actual type extensions, like the built-in VectorUDT, or for things that can benefit from a specific representation, like enumerations.
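For the enumeration case, a UDT can boil each value down to a single integer code. The sketch below is not part of the original example (the Currency type and its codes are made up) and just mirrors the Spark 1.x UserDefinedType API used above:
// Hypothetical enumeration-backed UDT: each Currency value is stored as a
// single IntegerType code rather than a string or a struct.
@SQLUserDefinedType(udt = classOf[CurrencyUDT])
sealed abstract class Currency(val code: Int) extends Serializable
case object USD extends Currency(0)
case object EUR extends Currency(1)

class CurrencyUDT extends UserDefinedType[Currency] {
  override def sqlType: DataType = IntegerType
  override def userClass: Class[Currency] = classOf[Currency]
  override def serialize(obj: Any): Any = obj.asInstanceOf[Currency].code
  override def deserialize(datum: Any): Currency = datum match {
    case 0 => USD
    case 1 => EUR
    case other => throw new IllegalArgumentException(s"Unknown currency code: $other")
  }
}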