I want a function that dynamically selects Spark DataFrame columns by their data type.
So far I have created:
object StructTypeHelpers {
  def selectColumnsByType[T <: DataType](schem: StructType): Seq[String] = {
    schem.filter(_.dataType.isInstanceOf[T]).map(_.name)
  }
}
so that StructTypeHelpers.selectColumnsByType[StringType](df.schema)
should work. However, the compiler warns me that:
abstract type T is unchecked since it is eliminated by erasure
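The warning is not harmless: because of erasure, the isInstanceOf[T] check effectively becomes isInstanceOf[DataType], so every column passes the filter. A minimal sketch with a hypothetical mixed schema illustrating this:

import org.apache.spark.sql.types._

// Hypothetical schema, only to show what the erased check does.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)
))

// Expected Seq("name"), but this returns Seq("id", "name"),
// because every dataType is an instance of the erased bound DataType.
StructTypeHelpers.selectColumnsByType[StringType](schema)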
When trying to use:
import scala.reflect.ClassTag
def selectColumnsByType[T <: DataType : ClassTag](schem: StructType): Seq[String]
it fails with
No ClassTag available for T
How can I get it to work and compile without the warning?
The idea is to filter only the columns that have the type you want and then do a select.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, IntegerType}
import spark.implicits._

val df = Seq(
  (1, 2, "hello")
).toDF("id", "count", "name")

// Keep only the fields whose dataType equals colType, then select them.
def selectByType(colType: DataType, df: DataFrame): DataFrame = {
  val cols = df.schema.toList
    .filter(x => x.dataType == colType)
    .map(c => col(c.name))

  df.select(cols: _*)
}

val res = selectByType(IntegerType, df)
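With the example DataFrame above, res should keep only the two integer columns. A quick check (output is what the spark-shell should print):

res.printSchema()
// root
//  |-- id: integer (nullable = false)
//  |-- count: integer (nullable = false)

One caveat of comparing with ==: parameterized types such as DecimalType must be given their exact precision and scale, e.g. DecimalType(38, 18). The ClassTag-based approach below avoids that.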
A literal answer, helped by the question How to know if an object is an instance of a TypeTag's type?, would be this:
var x = spark.table(...)

import org.apache.spark.sql.types._
import scala.reflect.{ClassTag, classTag}

def selectColumnsByType[T <: DataType : ClassTag](schema: StructType): Seq[String] = {
  schema.filter(field => classTag[T].runtimeClass.isInstance(field.dataType)).map(_.name)
}

selectColumnsByType[DecimalType](x.schema)
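Because the match is on the runtime class, this picks up DecimalType columns of any precision and scale. A small sketch against a hypothetical schema (not the table above):

import org.apache.spark.sql.types._

val s = StructType(Seq(
  StructField("price", DecimalType(10, 2)),
  StructField("amount", DecimalType(38, 18)),
  StructField("name", StringType)
))

selectColumnsByType[DecimalType](s)  // Seq("price", "amount")
selectColumnsByType[StringType](s)   // Seq("name")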
However, this form definitely makes it easier to use:
var x = spark.table(...)

import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import scala.reflect.{ClassTag, classTag}

class DataFrameHelpers(val df: DataFrame) {
  def selectColumnsByType[T <: DataType : ClassTag](): DataFrame = {
    val cols = df.schema
      .filter(field => classTag[T].runtimeClass.isInstance(field.dataType))
      .map(field => col(field.name))
    df.select(cols: _*)
  }
}

implicit def toDataFrameHelpers(df: DataFrame): DataFrameHelpers = new DataFrameHelpers(df)

x = x.selectColumnsByType[DecimalType]()
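As an aside (not part of the original answer), the same helper can be packaged as an implicit class, which avoids the separate conversion method. A sketch that works in the spark-shell or inside an enclosing object:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataType
import scala.reflect.{ClassTag, classTag}

// Same logic as above, bundled into an implicit class
// (implicit classes cannot be top-level in a plain .scala file).
implicit class RichDataFrame(val df: DataFrame) {
  def selectColumnsByType[T <: DataType : ClassTag](): DataFrame = {
    val cols = df.schema
      .filter(field => classTag[T].runtimeClass.isInstance(field.dataType))
      .map(field => col(field.name))
    df.select(cols: _*)
  }
}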
Note, though, as an earlier answer mentioned, isInstanceOf isn't really appropriate here, although it is helpful if you want to get all DecimalType columns regardless of precision. Using the more usual equality check, you could do the following instead, which also lets you specify multiple types!
var x = spark.table(...)

import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

class DataFrameHelpers(val df: DataFrame) {
  def selectColumnsByType(dt: DataType*): DataFrame = {
    val cols = df.schema
      .filter(field => dt.exists(_ == field.dataType))
      .map(field => col(field.name))
    df.select(cols: _*)
  }
}

implicit def toDataFrameHelpers(df: DataFrame): DataFrameHelpers = new DataFrameHelpers(df)

x = x.selectColumnsByType(ShortType, DecimalType(38, 18))
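For completeness, a small self-contained check of the varargs version (the column names are made up for illustration):

import spark.implicits._
import org.apache.spark.sql.types._

val sample = Seq((1.toShort, BigDecimal("1.5"), "a")).toDF("s", "d", "str")

// Scala BigDecimal columns default to decimal(38,18), so the exact-match
// varargs call keeps both "s" and "d" and drops the string column.
sample.selectColumnsByType(ShortType, DecimalType(38, 18)).printSchema()
// root
//  |-- s: short (nullable = false)
//  |-- d: decimal(38,18) (nullable = true)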