
spark select columns by type

I want a function that dynamically selects Spark DataFrame columns by their data type.

So far, I have created:

import org.apache.spark.sql.types.{DataType, StructType}

object StructTypeHelpers {
  def selectColumnsByType[T <: DataType](schem: StructType): Seq[String] = {
    schem.filter(_.dataType.isInstanceOf[T]).map(_.name)
  }
}

so that StructTypeHelpers.selectColumnsByType[StringType](df.schema) should work. However, the compiler warns me that:

abstract type T is unchecked since it is eliminated by erasure

When trying to use:

import scala.reflect.ClassTag
def selectColumnsByType[T <: DataType: ClassTag](schem: StructType): Seq[String]

it fails with

No ClassTag available for T

How can I get it to work and compile without the warning?

asked Feb 20 '19 by Georg Heiler

2 Answers

The idea is to filter the schema for the columns that have the type you want and then select them.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import spark.implicits._   // for toDF; spark is the SparkSession (e.g. in spark-shell)

val df = Seq(
  (1, 2, "hello")
).toDF("id", "count", "name")

// Keep only the columns whose dataType exactly matches colType
def selectByType(colType: DataType, df: DataFrame) = {
  val cols = df.schema.toList
    .filter(x => x.dataType == colType)
    .map(c => col(c.name))
  df.select(cols: _*)
}

val res = selectByType(IntegerType, df)
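
For instance, calling the same helper with StringType (an illustrative follow-up, assuming the df defined above) should keep only the name column:

// Hypothetical usage: keep only the string-typed columns; here just "name"
val strings = selectByType(StringType, df)
strings.columns  // Array(name)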
answered Oct 20 '22 by firsni


A literal answer, helped by "How to know if an object is an instance of a TypeTag's type?", would be this:

var x = spark.table(...)

import org.apache.spark.sql.types._
import scala.reflect.{ClassTag, classTag}

// Check each field's runtime type against T's runtime class,
// instead of isInstanceOf[T], which is erased at compile time.
def selectColumnsByType[T <: DataType : ClassTag](schema: StructType): Seq[String] = {
  schema.filter(field => classTag[T].runtimeClass.isInstance(field.dataType)).map(_.name)
}

selectColumnsByType[DecimalType](x.schema)
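
This helper only returns column names, so to actually project them you still pass them to select; a minimal sketch, assuming the same x and helper as above:

import org.apache.spark.sql.functions.col

// Turn the matched names into Columns and project them (sketch only)
val decimals = x.select(selectColumnsByType[DecimalType](x.schema).map(col): _*)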

However, this form definitely makes it easier to use:

var x = spark.table(...)

import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import scala.language.implicitConversions
import scala.reflect.{ClassTag, classTag}

class DataFrameHelpers(val df: DataFrame) {
  // Select every column whose dataType is an instance of T
  // (e.g. any DecimalType, regardless of precision and scale).
  def selectColumnsByType[T <: DataType : ClassTag](): DataFrame = {
    val cols = df.schema.filter(field => classTag[T].runtimeClass.isInstance(field.dataType)).map(field => col(field.name))
    df.select(cols: _*)
  }
}

implicit def toDataFrameHelpers(df: DataFrame): DataFrameHelpers = new DataFrameHelpers(df)

x = x.selectColumnsByType[DecimalType]()

Note, though, as an earlier answer mentioned, that isInstanceOf isn't really appropriate here, although it is helpful if you want to match all DecimalType columns regardless of precision. Using plain type equality instead, you could do the following, which also lets you specify multiple types:

var x = spark.table(...)

import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import scala.language.implicitConversions

class DataFrameHelpers(val df: DataFrame) {
  // Select every column whose dataType equals one of the given types exactly.
  def selectColumnsByType(dt: DataType*): DataFrame = {
    val cols = df.schema.filter(field => dt.exists(_ == field.dataType)).map(field => col(field.name))
    df.select(cols: _*)
  }
}

implicit def toDataFrameHelpers(df: DataFrame): DataFrameHelpers = new DataFrameHelpers(df)

x = x.selectColumnsByType(ShortType, DecimalType(38, 18))
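
To illustrate the difference (a sketch with made-up column names, assuming a SparkSession named spark): DecimalType(38,18) here matches only columns with exactly that precision and scale, whereas the ClassTag-based version above would also match, say, a DecimalType(10,2) column.

import spark.implicits._
import org.apache.spark.sql.functions.lit

// Hypothetical frame with two decimal columns of different precision/scale
val y = Seq(1).toDF("id")
  .withColumn("amount", lit("1.0").cast(DecimalType(38, 18)))
  .withColumn("rate", lit("2.5").cast(DecimalType(10, 2)))

y.selectColumnsByType(DecimalType(38, 18)).columns  // keeps only "amount"
// the ClassTag-based selectColumnsByType[DecimalType]() above would keep both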
answered Oct 20 '22 by wilbur4321