Spark schema from case class with correct nullability

For a custom Estimator's transformSchema method, I need to compare the schema of an input data frame to the schema defined in a case class. Usually this can be done as described in "Generate a Spark StructType / Schema from a case class" (see below). However, the wrong nullability is used:

The real schema of the DataFrame inferred by spark.read.csv().as[MySchema] might look like:

root
 |-- CUSTOMER_ID: integer (nullable = false)

And the case class:

case class MySchema(CUSTOMER_ID: Int)

To compare I use:

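import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType
val rawSchema = ScalaReflection.schemaFor[MySchema].dataType.asInstanceOf[StructType]
if (!rawSchema.equals(rawDf.schema)) { /* schemas differ */ }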

Unfortunately, the equality check always yields false, because the schema derived from the case class sets nullable to true (since a java.lang.Integer might actually be null):

root
 |-- CUSTOMER_ID: integer (nullable = true)

How can I specify nullable = false when creating the schema?

Asked Dec 23 '22 at 23:12 by Georg Heiler


1 Answer

Arguably you're mixing things which don't really belong in the same space. ML Pipelines are inherently dynamic and introducing statically typed objects doesn't really change that.

Moreover, the schema for a class defined as:

case class MySchema(CUSTOMER_ID: Int)

will have a non-nullable CUSTOMER_ID, because scala.Int is not the same as java.lang.Integer:

scala> import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor

scala> case class MySchema(CUSTOMER_ID: Int)
defined class MySchema

scala> schemaFor[MySchema].dataType
res0: org.apache.spark.sql.types.DataType = StructType(StructField(CUSTOMER_ID,IntegerType,false))

That being said, if you want nullable fields, use Option[Int]:

case class MySchema(CUSTOMER_ID: Option[Int])

and if you want non-nullable fields, use Int as above.
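The mapping from field type to nullability can be checked side by side. This is a minimal sketch, assuming Spark SQL (including the catalyst module, whose `ScalaReflection` is internal API and may change between versions) is on the classpath; the case classes `WithInt`, `WithOption` and `WithBoxed` and the helper `firstFieldNullable` are hypothetical names used only for illustration.

```scala
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.types.StructType

// Hypothetical case classes that differ only in the type of CUSTOMER_ID.
case class WithInt(CUSTOMER_ID: Int)
case class WithOption(CUSTOMER_ID: Option[Int])
case class WithBoxed(CUSTOMER_ID: java.lang.Integer)

object NullabilityDemo {
  // Nullability of the first (and only) field in the schema derived for T.
  def firstFieldNullable[T: TypeTag]: Boolean =
    schemaFor[T].dataType.asInstanceOf[StructType].fields.head.nullable

  def main(args: Array[String]): Unit = {
    println(firstFieldNullable[WithInt])    // false: scala.Int cannot hold null
    println(firstFieldNullable[WithOption]) // true: None is encoded as null
    println(firstFieldNullable[WithBoxed])  // true: a boxed Integer may be null
  }
}
```

Only the boxed java.lang.Integer and Option[Int] variants produce nullable = true, which is why the choice of field type in the case class controls the derived flag.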

Another problem here is that for CSV every field is nullable by definition, and this state is "inherited" by the encoded Dataset. So in practice:

spark.read.csv(...)

will always result in:

root
 |-- CUSTOMER_ID: integer (nullable = true)

and this is why you get a schema mismatch. Unfortunately, it is not possible to override the nullable flag for sources that don't enforce nullability constraints, such as CSV or JSON.
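One way to observe this is to pass an explicitly non-nullable schema to the CSV reader and inspect what Spark reports back. A minimal sketch, assuming a local Spark 3.x session; `CsvNullabilityDemo`, `customerIdNullable` and the sample file contents are hypothetical. Given the behaviour described above, CUSTOMER_ID is expected to come back as nullable even though the requested schema said otherwise.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.Files

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object CsvNullabilityDemo {
  // Reads a small CSV file with an explicitly non-nullable schema and
  // returns the nullability Spark reports for CUSTOMER_ID afterwards.
  def customerIdNullable(spark: SparkSession): Boolean = {
    // Hypothetical sample data written to a temporary file.
    val path = Files.createTempFile("customers", ".csv")
    Files.write(path, "1\n2\n".getBytes(StandardCharsets.UTF_8))

    val requested = StructType(Seq(StructField("CUSTOMER_ID", IntegerType, nullable = false)))
    val df = spark.read.schema(requested).csv(path.toString)
    df.schema("CUSTOMER_ID").nullable
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("csv-nullability-demo")
      .getOrCreate()
    try println(customerIdNullable(spark))
    finally spark.stop()
  }
}
```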

If having a non-nullable schema is a hard requirement, you could try:

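import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.types.StructType
import spark.implicits._  // provides the Encoder needed by .as[MySchema]

spark.createDataFrame(
  spark.read.csv(...).rdd,
  schemaFor[MySchema].dataType.asInstanceOf[StructType]
).as[MySchema]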
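To see the createDataFrame approach keep the non-nullable flag without depending on the elided CSV reader options, the sketch below swaps the CSV source for a small in-memory RDD[Row] whose values are already Ints. It assumes a local Spark 3.x session; `CreateDataFrameDemo` and `strictDataset` are hypothetical names.

```scala
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.types.StructType

case class MySchema(CUSTOMER_ID: Int)

object CreateDataFrameDemo {
  // Builds a Dataset[MySchema] whose schema comes straight from the case class,
  // so CUSTOMER_ID keeps nullable = false.
  def strictDataset(spark: SparkSession): Dataset[MySchema] = {
    import spark.implicits._ // Encoder for MySchema, needed by .as[MySchema]

    // In-memory rows stand in for spark.read.csv(...).rdd; values must already be Ints.
    val rows = spark.sparkContext.parallelize(Seq(Row(1), Row(2)))
    val schema = schemaFor[MySchema].dataType.asInstanceOf[StructType]
    spark.createDataFrame(rows, schema).as[MySchema]
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("strict-schema-demo").getOrCreate()
    try strictDataset(spark).printSchema()
    finally spark.stop()
  }
}
```

Note that the row values must already have the types declared in the schema; createDataFrame does not convert strings to integers.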

This approach is valid only if you know that the data is actually null-free. Any null value will lead to a runtime exception.
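If the data cannot be guaranteed null-free, another option (not part of the answer above) is to relax the check in transformSchema so that it compares names and types but ignores nullability. Spark's own helpers for this are not public API, so this sketch defines a hypothetical `relaxNullability` function that marks every field, array element and map value as nullable before comparing; field metadata is still compared.

```scala
import org.apache.spark.sql.types._

object SchemaCompare {
  // Recursively marks every field, array element and map value as nullable,
  // so that two schemas differing only in nullability become equal.
  def relaxNullability(dt: DataType): DataType = dt match {
    case StructType(fields) =>
      StructType(fields.map(f => f.copy(dataType = relaxNullability(f.dataType), nullable = true)))
    case ArrayType(elementType, _) =>
      ArrayType(relaxNullability(elementType), containsNull = true)
    case MapType(keyType, valueType, _) =>
      MapType(relaxNullability(keyType), relaxNullability(valueType), valueContainsNull = true)
    case other => other
  }

  // True when the schemas match in names, types and metadata, ignoring nullability.
  def sameIgnoringNullability(expected: StructType, actual: StructType): Boolean =
    relaxNullability(expected) == relaxNullability(actual)
}
```

In the question's transformSchema, `SchemaCompare.sameIgnoringNullability(rawSchema, rawDf.schema)` would take the place of `rawSchema.equals(rawDf.schema)`.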

Answered Mar 08 '23 at 09:03 by zero323