Spark collect_list and limit resulting list

I have a dataframe of the following format:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

What I want to do is group the dataframe by name, collect the values into a list, and limit the size of that list.

This is how I group by the name and collect the list:

import org.apache.spark.sql.functions.{col, collect_list}

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

The resulting dataframe is something like:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

What I want to do is limit the size of the produced list for each key. I've tried multiple ways to do that, but had no success. I've already seen some posts that suggest third-party solutions, but I want to avoid those. Is there a way?

pirox22 asked Sep 23 '18 15:09


1 Answer

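A UDF applied after collect_list does what you need, but it only trims each list after the whole group has been collected. If you're looking for a more performant approach that is also memory-sensitive, the way to do it is to write a UDAF. Unfortunately, the public UDAF API is not as extensible as the aggregate functions that ship with Spark. However, you can use Spark's internal APIs to build on its internal aggregate functions and do what you need (keep in mind that these internal APIs are not stable and can change between Spark versions).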
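To make the memory point concrete, here is a stdlib-only Scala model (no Spark involved) of the "collect everything, then truncate" approach, which is roughly what collect_list followed by a truncating UDF amounts to. `TruncateAfterCollect`, `collectThenTruncate` and the sample rows are hypothetical names chosen for illustration; the only point is that each key's full list exists before `take` cuts it down.

```scala
// Stdlib-only model (no Spark): collect each group in full, then truncate.
// TruncateAfterCollect and collectThenTruncate are hypothetical names.
object TruncateAfterCollect {
  def collectThenTruncate[K, V](rows: Seq[(K, V)], limit: Int): Map[K, Seq[V]] =
    rows.groupBy(_._1).map { case (key, pairs) =>
      val full = pairs.map(_._2) // the complete list for this key is built first
      key -> full.take(limit)    // and only then cut down to `limit` elements
    }

  def main(args: Array[String]): Unit = {
    val rows = Seq("key1" -> "value1", "key1" -> "value2", "key1" -> "value3", "key2" -> "value4")
    println(collectThenTruncate(rows, 2))
  }
}
```

The result is correct, but peak memory per key is proportional to the full group size rather than to the limit, which is what a bounded aggregate avoids.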

Here is an implementation of collect_list_limit that is mostly a copy-paste of Spark's internal CollectList aggregate function. I would just extend it, but it's a case class. All that's really needed is to override the update and merge methods so that they respect a passed-in limit:

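import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate}

case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  // The limit is evaluated once, without an input row, so it must be a
  // constant expression such as a literal
  val limit = limitExp.eval( null ).asInstanceOf[Int]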

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  // Once this partial buffer holds `limit` elements, further rows are ignored
  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  // Combine two partial buffers without ever returning more than `limit` elements
  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}
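Why do both update and merge need the cap? Roughly, Spark first folds each partition's rows into a partial buffer with update, then combines the partial buffers with merge. The stdlib-only model below (no Spark; `BoundedCollectModel`, `aggregate` and the partition layout are hypothetical) copies the update/merge logic of CollectListLimit so you can check that no partial or final buffer ever grows past the limit.

```scala
import scala.collection.mutable

// Stdlib-only model (no Spark) of a bounded collect_list.
// update/merge copy the logic of CollectListLimit; the names are hypothetical.
object BoundedCollectModel {
  type Buf = mutable.ArrayBuffer[Any]

  // Append a value only while the partial buffer is below the limit.
  def update(buffer: Buf, value: Any, limit: Int): Buf =
    if (buffer.size < limit) buffer += value else buffer

  // Combine two partial buffers, never returning more than `limit` elements.
  def merge(buffer: Buf, other: Buf, limit: Int): Buf =
    if (buffer.size >= limit) buffer
    else if (other.size >= limit) other
    else (buffer ++= other).take(limit)

  // Fold each "partition" with update, then merge the partial results.
  def aggregate(partitions: Seq[Seq[Any]], limit: Int): Buf =
    partitions
      .map(_.foldLeft(mutable.ArrayBuffer.empty[Any])((b, v) => update(b, v, limit)))
      .reduce((a, b) => merge(a, b, limit))

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))
    println(aggregate(partitions, 2))
  }
}
```

Which elements survive depends on partition and merge order, which Spark does not guarantee, so (as with collect_list itself) you should not rely on which items are kept, only on how many.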

To actually register it, we can use Spark's internal FunctionRegistry, which takes the function name and a builder. The builder is effectively a function that creates a CollectListLimit from the provided argument expressions:

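import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Expression
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )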
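The builder receives the argument expressions of the call, in order, and returns the aggregate expression. The stdlib-only model below illustrates that shape and one consequence of `limitExp.eval( null )`: the limit has to be a constant, because a column reference cannot be evaluated without a row. Every name in it (`Expr`, `Lit`, `ColRef`, `ListLimit`, `MiniRegistry`) is a hypothetical stand-in, not Spark's API, and the arity check in the builder is an addition over indexing `args( 0 )` and `args( 1 )` directly.

```scala
import scala.collection.mutable

// Stdlib-only model (no Spark) of a name -> builder registry.
// All types here are hypothetical stand-ins for Spark's classes.
sealed trait Expr { def eval(row: Map[String, Any]): Any }

final case class Lit(value: Any) extends Expr {
  def eval(row: Map[String, Any]): Any = value
}

final case class ColRef(name: String) extends Expr {
  def eval(row: Map[String, Any]): Any = row(name) // fails when row is null
}

final case class ListLimit(child: Expr, limitExp: Expr) extends Expr {
  // Mirrors limitExp.eval(null): evaluated once with no row, so only a constant works.
  val limit: Int = limitExp.eval(null).asInstanceOf[Int]
  // Aggregation itself is not modelled; this just delegates to the child.
  def eval(row: Map[String, Any]): Any = child.eval(row)
}

object MiniRegistry {
  type Builder = Seq[Expr] => Expr
  private val builders = mutable.Map.empty[String, Builder]

  def register(name: String, builder: Builder): Unit = builders(name) = builder

  // Looks up the builder by name and hands it the call's argument expressions.
  def lookup(name: String, args: Seq[Expr]): Expr = builders(name)(args)

  // Checks the arity instead of indexing the argument list blindly.
  val listLimitBuilder: Builder = {
    case Seq(child, limit) => ListLimit(child, limit)
    case other => throw new IllegalArgumentException(
      s"collect_list_limit expects 2 arguments, got ${other.size}")
  }
}
```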

Edit:

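It turns out that adding it to the builtin registry only works if you haven't created the SparkSession yet, because the session makes its own copy of the builtin registry on startup and functions registered afterwards are not visible to that copy. If you already have a running session, this should work to add it via reflection: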

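import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
// functionRegistry is a private field, so getDeclaredFields (not getFields) is needed
val field = classOf[SessionCatalog].getDeclaredFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )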
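The reflection step has to find a private field: `Class.getFields` returns only public fields, while scalac compiles a constructor parameter that is used outside the constructor into a private field, which only `getDeclaredFields` lists. The stdlib-only illustration below shows this on a small class; `Holder`, its `registry` parameter and `FieldLookup` are hypothetical stand-ins for SessionCatalog and its functionRegistry.

```scala
import java.lang.reflect.Field

// Stdlib-only illustration (no Spark). Holder and `registry` are hypothetical
// stand-ins for SessionCatalog and its functionRegistry.
class Holder(registry: String) {
  // Using the parameter in a method makes scalac keep it as a private field.
  def describe: String = s"registry=$registry"
}

object FieldLookup {
  // getFields only lists public fields, so the private `registry` is not found.
  def viaGetFields: Option[Field] =
    classOf[Holder].getFields.find(_.getName.endsWith("registry"))

  // getDeclaredFields also lists private fields declared on the class itself.
  def viaGetDeclaredFields: Option[Field] =
    classOf[Holder].getDeclaredFields.find(_.getName.endsWith("registry"))

  // Reads the private field's value from a specific instance.
  def readRegistry(holder: Holder): String = {
    val field = viaGetDeclaredFields.get
    field.setAccessible(true)
    field.get(holder).asInstanceOf[String]
  }

  def main(args: Array[String]): Unit = {
    val holder = new Holder("in-use")
    println(s"found via getFields: ${viaGetFields.isDefined}")
    println(s"value via getDeclaredFields: ${readRegistry(holder)}")
  }
}
```

Matching with `endsWith` rather than an exact name is a reasonable choice here, since scalac can mangle private field names (prefixing them with the package and class path) depending on how they are accessed.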
user1084563 answered Sep 30 '22 02:09