How to use countDistinct in Scala with Spark?

I've tried to use the countDistinct function, which should be available in Spark 1.5 according to the Databricks blog. However, I got the following exception:

Exception in thread "main" org.apache.spark.sql.AnalysisException: undefined function countDistinct;

On the Spark developers' mailing list I found a suggestion to use the count and distinct functions together to get the same result that countDistinct should produce:

count(distinct <columnName>)
// instead of
countDistinct(<columnName>)

Because I build aggregation expressions dynamically from a list of aggregation function names, I'd prefer to avoid special cases that require different treatment.

So, is it possible to unify this by:

  • registering a new UDAF which will be an alias for count(distinct columnName),
  • manually registering the CountDistinct function already implemented in Spark, probably via one of the following imports:

    import org.apache.spark.sql.catalyst.expressions.{CountDistinctFunction, CountDistinct}

  • or doing it in any other way?

EDIT: Example (with some local references and unnecessary code removed):

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, SQLContext, DataFrame}
import org.apache.spark.sql.functions._

import scala.collection.mutable.ListBuffer


class Flattener(sc: SparkContext) {
  val sqlContext = new SQLContext(sc)

  def flatTable(data: DataFrame, groupField: String): DataFrame = {
    val flatteningExpressions = data.columns.zip(TypeRecognizer.getTypes(data)).
      flatMap(x => getFlatteningExpressions(x._1, x._2)).toList

    data.groupBy(groupField).agg(
      expr(s"count($groupField) as groupSize"),
      flatteningExpressions:_*
    )
  }

  private def getFlatteningExpressions(fieldName: String, fieldType: DType): List[Column] = {
    val aggFuncs = getAggregationFunctions(fieldType)

    aggFuncs.map(f => expr(s"$f($fieldName) as ${fieldName}_$f"))
  }

  private def getAggregationFunctions(fieldType: DType): List[String] = {
    val aggFuncs = new ListBuffer[String]()

    if(fieldType == DType.NUMERIC) {
      aggFuncs += ("avg", "min", "max")
    }

    if(fieldType == DType.CATEGORY) {
      aggFuncs += "countDistinct"
    }

    aggFuncs.toList
  }

}
asked Nov 03 '15 by Adam H

2 Answers

countDistinct can be used in two different forms:

df.groupBy("A").agg(expr("count(distinct B)")

or

df.groupBy("A").agg(countDistinct("B"))

However, neither of these forms works when you want to use it in the same aggregation as a custom UDAF (implemented as a UserDefinedAggregateFunction in Spark 1.5):

// Assume that we have already implemented and registered StdDev UDAF 
df.groupBy("A").agg(countDistinct("B"), expr("StdDev(B)"))

// Will cause
Exception in thread "main" org.apache.spark.sql.AnalysisException: StdDev is implemented based on the new Aggregate Function interface and it cannot be used with functions implemented based on the old Aggregate Function interface.;

Due to this limitation, the most reasonable approach seems to be implementing countDistinct as a UDAF, which should allow all functions to be treated the same way and let countDistinct be used alongside other UDAFs.

An example implementation could look like this:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CountDistinct extends UserDefinedAggregateFunction {
  // Input: a single string column whose distinct values are counted.
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  // Buffer: the set of values seen so far, stored as an array.
  override def bufferSchema: StructType = StructType(
    StructField("items", ArrayType(StringType, true)) :: Nil
  )

  override def dataType: DataType = IntegerType

  // The result depends only on the input values, so the function is deterministic.
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }
}
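
With that class in place, the UDAF can be registered under a name and invoked from string-built expressions just like the built-in aggregates, which is what the question needs. A minimal usage sketch, assuming a Spark 1.5 sqlContext; the registered name and the DataFrame df are illustrative:

sqlContext.udf.register("countDistinct", new CountDistinct)

// The string-based expression builder can now treat countDistinct
// like any other aggregation function:
df.groupBy("A").agg(expr("countDistinct(B) as B_countDistinct"))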
answered Sep 28 '22 by Adam H

Not sure if I really understood your problem, but this is an example of the countDistinct aggregate function:

val values = Array((1, 2), (1, 3), (2, 2), (1, 2))
// toDF is available in spark-shell; otherwise add: import sqlContext.implicits._
val myDf = sc.parallelize(values).toDF("id", "foo")
import org.apache.spark.sql.functions.countDistinct
myDf.groupBy('id).agg(countDistinct('foo) as 'distinctFoo).show()
/**
+---+-----------+
| id|distinctFoo|
+---+-----------+
|  1|          2|
|  2|          1|
+---+-----------+
*/
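
If you build aggregation expressions from strings, as the question's Flattener does, the SQL form should give the same result; a small sketch under the same assumptions:

import org.apache.spark.sql.functions.expr

// Equivalent aggregation built from a plain string:
myDf.groupBy("id").agg(expr("count(distinct foo) as distinctFoo")).show()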
answered Sep 28 '22 by alghimo