GroupBy operator for Kotlin Flow

I am trying to switch from RxJava to Kotlin Flow. Flow is really impressive, but is there currently any operator in Kotlin Flow similar to RxJava's "GroupBy"?

Sundaravel asked Oct 30 '19 12:10



1 Answer

As of Kotlin Coroutines 1.3, the kotlinx.coroutines library doesn't seem to provide this operator. However, since the design of Flow is such that all operators are extension functions, there is no fundamental distinction between the library providing it and you writing your own.
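To make the point about extension functions concrete, here is a minimal sketch of a user-defined operator. The name withPosition is made up for this illustration and is not part of kotlinx.coroutines; the only building blocks it relies on are the flow builder, collect and emit.

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Hypothetical operator (not a library function): pairs each upstream item
// with its zero-based position in the flow.
fun <T> Flow<T>.withPosition(): Flow<Pair<Int, T>> = flow {
    var index = 0
    collect { item -> emit(index++ to item) }
}

fun main(): Unit = runBlocking {
    val result = flowOf("a", "b", "c").withPosition().toList()
    println(result) // [(0, a), (1, b), (2, c)]
}
```

From the caller's side it is used exactly like a built-in operator, which is all that a custom groupBy needs as well.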

With that in mind, here are some of my ideas on how to approach it.

1. Collect Each Group to a List

If you just need a list of all items for each key, use this simple implementation that emits pairs of (K, List<T>):

import kotlinx.coroutines.flow.*

// Buffers the whole upstream in memory and emits the groups only after it completes.
fun <T, K> Flow<T>.groupToList(getKey: (T) -> K): Flow<Pair<K, List<T>>> = flow {
    val storage = mutableMapOf<K, MutableList<T>>()
    collect { t -> storage.getOrPut(getKey(t)) { mutableListOf() } += t }
    storage.forEach { (k, ts) -> emit(k to ts) }
}

For this example:

suspend fun main() {
    val input = 1..10
    input.asFlow()
            .groupToList { it % 2 }
            .collect { println(it) }
}

it prints

(1, [1, 3, 5, 7, 9])
(0, [2, 4, 6, 8, 10])
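Since this approach holds every item in memory anyway, a similar result can also be had by collecting the flow into a list and applying the standard library's groupBy for collections, which returns a Map rather than a flow of pairs. A minimal sketch follows; toGroupedMap is a hypothetical helper name chosen for this example.

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Hypothetical helper: collect the flow into a List, then group it with
// kotlin.collections.groupBy (a collection function, not a Flow operator).
suspend fun <T, K> Flow<T>.toGroupedMap(getKey: (T) -> K): Map<K, List<T>> =
    toList().groupBy(getKey)

fun main(): Unit = runBlocking {
    val grouped = (1..10).asFlow().toGroupedMap { it % 2 }
    println(grouped) // {1=[1, 3, 5, 7, 9], 0=[2, 4, 6, 8, 10]}
}
```

The trade-off is that the result is no longer a Flow, so it cannot be chained with further flow operators without converting it back.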

2.a Emit a Flow for Each Group

If you need the full RxJava semantics where you transform the input flow into many output flows (one per distinct key), things get more involved.

Whenever you see a new key in the input, you must emit a new inner flow to the downstream and then, asynchronously, keep pushing more data into it whenever you encounter the same key again.

Here's an implementation that does this:

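import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*

fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                // First occurrence of this key: create a buffered channel and emit it downstream as a flow.
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t) // suspends while this group's buffer is full
        }
    } finally {
        // Close every channel so that the inner flows complete.
        storage.values.forEach { chan -> chan.close() }
    }
}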

It sets up a Channel for each key and exposes the channel to the downstream as a flow.

2.b Concurrently Collect and Reduce Grouped Flows

Since groupBy keeps emitting the data to the inner flows after emitting the flows themselves to the downstream, you have to be very careful with how you collect them.

You must collect all the inner flows concurrently, with no upper limit on the level of concurrency. Otherwise the channels of the flows that are queued for later collection will eventually block the sender and you'll end up with a deadlock.
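As a rough illustration of the problem, the sketch below takes the limiting case of a concurrency level of one: it collects each inner flow to completion before returning to the outer flow. The helper collectSequentially and the 500 ms timeout are assumptions made for this demo; groupBy is repeated from above so that the snippet is self-contained.

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull

// Same groupBy as in section 2.a.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

// Hypothetical demo helper: collects the groups one at a time and returns
// null if the pipeline has not finished within timeoutMillis.
suspend fun collectSequentially(timeoutMillis: Long): String? =
    withTimeoutOrNull(timeoutMillis) {
        (1..10).asFlow()
            .groupBy { it % 2 }
            // The inner collect waits for items, but the upstream cannot send
            // any until this lambda returns, so neither side makes progress.
            .collect { (_, inner) -> inner.collect { } }
        "finished"
    }

fun main(): Unit = runBlocking {
    println(collectSequentially(500)) // null: the pipeline stalled until the timeout cancelled it
}
```

With a concurrency limit greater than one the stall shows up later, once more groups exist than collectors, but the mechanism is the same.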

Here is a function that does this properly:

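import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.*

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
        reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKey
                // start reducing each inner flow as soon as it arrives
                .map { (key, flow) -> key to async { flow.reduce() } }
                // collect the whole upstream, launching every reducer along the way
                .toList()
                // emit the results in the order in which the keys first appeared
                .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}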

The map stage launches a coroutine for each inner flow it receives. The coroutine reduces it to the final result.

toList() is a terminal operation that collects the entire upstream flow, launching all the async coroutines in the process. The coroutines start consuming the inner flows even while we're still collecting the main flow. This is essential to prevent a deadlock.

Finally, after all the coroutines have been launched, we start a forEach loop that awaits each result in turn and emits it, so the results come out in the order in which the keys first appeared in the input.

You can implement almost the same behavior in terms of flatMapMerge:

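// flatMapMerge is marked as a preview/experimental API in kotlinx.coroutines,
// so the compiler may warn or ask for an opt-in.
fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
        reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flatMapMerge(Int.MAX_VALUE) { (key, inner) ->
    // "inner" rather than "flow", so that the name doesn't shadow the flow { } builder
    flow { emit(key to inner.reduce()) }
}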

The difference is in the ordering: whereas the first implementation respects the order of appearance of keys in the input, this one doesn't. Both perform similarly.
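To see the ordering difference side by side, here is a small sketch that runs both variants on the same input. The names reducePerKeyOrdered and reducePerKeyMerged are made up so that both can live in one file, and the opt-in annotation is an assumption about what your library version requires.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*

// Same groupBy as in section 2.a.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

// First variant under a hypothetical name: async per group, results awaited in key order.
fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKeyOrdered(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKeyOrdered
            .map { (key, inner) -> key to async { inner.reduce() } }
            .toList()
            .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}

// Second variant under a hypothetical name: results are emitted whenever a group finishes.
@OptIn(FlowPreview::class, ExperimentalCoroutinesApi::class)
fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKeyMerged(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flatMapMerge(Int.MAX_VALUE) { (key, inner) ->
    flow { emit(key to inner.reduce()) }
}

fun main(): Unit = runBlocking {
    val input = (1..10).asFlow()
    val ordered = input.groupBy { it % 3 }.reducePerKeyOrdered { toList() }.toList()
    val merged = input.groupBy { it % 3 }.reducePerKeyMerged { toList() }.toList()
    println(ordered.map { it.first }) // [1, 2, 0], the order in which the keys first appear
    println(merged.toMap() == ordered.toMap()) // true: same groups, but their order is not guaranteed
}
```

If downstream code depends on seeing the groups in input order, the first variant is the safer choice.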

3. Example

This example groups and sums 40 million integers:

suspend fun main() {
    val input = 1..40_000_000
    input.asFlow()
            .groupBy { it % 100 }
            .reducePerKey { sum { it.toLong() } }
            .collect { println(it) }
}

suspend fun <T> Flow<T>.sum(toLong: suspend (T) -> Long): Long {
    var sum = 0L
    collect { sum += toLong(it) }
    return sum
}

I can successfully run this with -Xmx64m. On my 4-core laptop I'm getting about 4 million items per second.

It is simple to redefine the first solution in terms of the new one like this:

fun <T, K> Flow<T>.groupToList(getKey: (T) -> K): Flow<Pair<K, List<T>>> =
        groupBy(getKey).reducePerKey { toList() }
Marko Topolnik answered Nov 11 '22 11:11