I tried to use the countDistinct function, which should be available in Spark 1.5 according to the Databricks blog. However, I got the following exception:
Exception in thread "main" org.apache.spark.sql.AnalysisException: undefined function countDistinct;
I found that on the Spark developers' mailing list they suggest using the count and distinct functions to get the same result that countDistinct should produce:
count(distinct <columnName>)
// instead of
countDistinct(<columnName>)
Because I build aggregation expressions dynamically from a list of aggregation function names, I'd prefer not to have any special cases that require different treatment.
So, is it possible to unify this by:
manually registering the CountDistinct function that is already implemented in Spark, which is probably one of those in the following import:
import org.apache.spark.sql.catalyst.expressions.{CountDistinctFunction, CountDistinct}
or do it in any other way?
EDIT: Example (with some local references and unnecessary code removed):
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, SQLContext, DataFrame}
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer
class Flattener(sc: SparkContext) {
  val sqlContext = new SQLContext(sc)

  def flatTable(data: DataFrame, groupField: String): DataFrame = {
    val flatteningExpressions = data.columns.zip(TypeRecognizer.getTypes(data)).
      flatMap(x => getFlatteningExpressions(x._1, x._2)).toList
    data.groupBy(groupField).agg (
      expr(s"count($groupField) as groupSize"),
      flatteningExpressions:_*
    )
  }

  private def getFlatteningExpressions(fieldName: String, fieldType: DType): List[Column] = {
    val aggFuncs = getAggregationFunctons(fieldType)
    aggFuncs.map(f => expr(s"$f($fieldName) as ${fieldName}_$f"))
  }

  private def getAggregationFunctons(fieldType: DType): List[String] = {
    val aggFuncs = new ListBuffer[String]()
    if(fieldType == DType.NUMERIC) {
      aggFuncs += ("avg", "min", "max")
    }
    if(fieldType == DType.CATEGORY) {
      aggFuncs += "countDistinct"
    }
    aggFuncs.toList
  }
}
countDistinct can be used in two different forms:
df.groupBy("A").agg(expr("count(distinct B)"))
or
df.groupBy("A").agg(countDistinct("B"))
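Since the first form is an ordinary SQL string, the special case from the question could be kept in a single lookup table when the expression strings are built. The following is only a sketch in plain Scala: AggExprBuilder, its templates map, and the "column_function" alias format are hypothetical names chosen to match the question's code, not part of Spark.

```scala
// Hypothetical helper (not part of Spark): turns an aggregation function name
// and a column name into a SQL expression string suitable for functions.expr.
object AggExprBuilder {
  // Function names that need a non-standard SQL form.
  // Anything not listed here falls back to the plain "name(column)" form.
  private val templates: Map[String, String => String] = Map(
    "countDistinct" -> ((col: String) => s"count(distinct $col)")
  )

  // Builds e.g. "avg(B) as B_avg" or "count(distinct B) as B_countDistinct".
  def build(func: String, column: String): String = {
    val call = templates.get(func).map(_(column)).getOrElse(s"$func($column)")
    s"$call as ${column}_$func"
  }
}
```

In the question's getFlatteningExpressions, the mapping step would then presumably become aggFuncs.map(f => expr(AggExprBuilder.build(f, fieldName))), so the calling code treats every function name the same way.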
However, neither form of countDistinct works when you want to use it on the same column together with your custom UDAF (implemented as UserDefinedAggregateFunction in Spark 1.5):
// Assume that we have already implemented and registered StdDev UDAF
df.groupBy("A").agg(countDistinct("B"), expr("StdDev(B)"))
// Will cause
Exception in thread "main" org.apache.spark.sql.AnalysisException: StdDev is implemented based on the new Aggregate Function interface and it cannot be used with functions implemented based on the old Aggregate Function interface.;
Because of these limitations, the most reasonable approach seems to be implementing countDistinct as a UDAF, which should allow all functions to be treated in the same way and allow countDistinct to be used along with other UDAFs.
An example implementation could look like this:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class CountDistinct extends UserDefinedAggregateFunction {
  // A single string column is expected as input
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip nulls so the result matches COUNT(DISTINCT ...), which ignores them
    if (!input.isNullAt(0)) {
      buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
    }
  }

  // The buffer holds the distinct values seen so far
  override def bufferSchema: StructType = StructType(
    StructField("items", ArrayType(StringType, true)) :: Nil
  )

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  override def deterministic: Boolean = true

  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }

  override def dataType: DataType = IntegerType
}
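To see why the set union in merge gives the right answer when partial results come from different partitions, the buffer logic can be modelled without Spark. This is a sketch only: DistinctBufferModel and its method names are hypothetical, and the two input sequences stand in for two partitions of one group.

```scala
// Plain-Scala model of the buffer handling in the UDAF above; no Spark classes involved.
object DistinctBufferModel {
  type Buffer = Seq[String]

  // Corresponds to initialize: start with no values seen.
  val initial: Buffer = Seq.empty

  // Corresponds to the set-union step in update: add one value, dropping duplicates.
  def update(buffer: Buffer, value: String): Buffer =
    (buffer.toSet + value).toSeq

  // Corresponds to merge: union of two partial buffers.
  def merge(left: Buffer, right: Buffer): Buffer =
    (left.toSet ++ right.toSet).toSeq

  // Corresponds to evaluate: number of distinct values collected.
  def evaluate(buffer: Buffer): Int = buffer.length

  // Aggregates two hypothetical partitions independently, then merges them.
  def countDistinct(partition1: Seq[String], partition2: Seq[String]): Int = {
    val b1 = partition1.foldLeft(initial)(update)
    val b2 = partition2.foldLeft(initial)(update)
    evaluate(merge(b1, b2))
  }
}
```

A value that appears in both partitions is counted once, because the union removes it before evaluate runs. To use the real class from expr, it would presumably be registered under a name first, for example with sqlContext.udf.register("countDistinct", new CountDistinct); that wiring is an assumption and has not been tested here.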
I'm not sure I fully understood your problem, but here is an example of the countDistinct aggregate function:
val values = Array((1, 2), (1, 3), (2, 2), (1, 2))
val myDf = sc.parallelize(values).toDF("id", "foo")
import org.apache.spark.sql.functions.countDistinct
myDf.groupBy('id).agg(countDistinct('foo) as 'distinctFoo) show
/**
+---+-------------------+
| id|COUNT(DISTINCT foo)|
+---+-------------------+
|  1|                  2|
|  2|                  1|
+---+-------------------+
*/
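The counts in that output can be cross-checked with ordinary Scala collections. This is a sketch outside Spark; DistinctPerKey and distinctCounts are hypothetical names, and the input is the same list of (id, foo) pairs as above.

```scala
// Plain-Scala cross-check of the per-id distinct counts; no Spark required.
object DistinctPerKey {
  // Same pairs as in the DataFrame example: (id, foo)
  val values: Array[(Int, Int)] = Array((1, 2), (1, 3), (2, 2), (1, 2))

  // For every id, count how many different foo values occur.
  def distinctCounts(pairs: Seq[(Int, Int)]): Map[Int, Int] =
    pairs.groupBy(_._1).map { case (id, rows) => id -> rows.map(_._2).distinct.size }
}
```

Calling DistinctPerKey.distinctCounts(DistinctPerKey.values.toSeq) yields 2 for id 1 and 1 for id 2, in line with the table printed by show.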