Skip to content

Commit

Permalink
Optimize Flow.combine
Browse files Browse the repository at this point in the history
    * Get rid of two code paths
    * Get rid of accidental O(N^2) where N is the number of flows
    * Get rid of select that hits performance hard

Fixes #2296
  • Loading branch information
qwwdfsad committed Oct 15, 2020
1 parent ec9d084 commit bb167b3
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 148 deletions.
26 changes: 0 additions & 26 deletions benchmarks/build.gradle.kts
Expand Up @@ -31,33 +31,7 @@ tasks.named<KotlinCompile>("compileJmhKotlin") {
}
}

/*
* Due to a bug in the inliner it sometimes does not remove inlined symbols (that are later renamed) from unused code paths,
* and it breaks JMH that tries to post-process these symbols and fails because they are renamed.
*/
val removeRedundantFiles by tasks.registering(Delete::class) {
    // All stale class files are registered through a single vararg call;
    // the targets and their order are exactly the ones listed one-by-one before.
    delete(
        // Scrabble
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$buildHistoOnScore\$1\$\$special\$\$inlined\$filter\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$nBlanks\$1\$\$special\$\$inlined\$map\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$score2\$1\$\$special\$\$inlined\$map\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$bonusForDoubleLetter\$1\$\$special\$\$inlined\$map\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$nBlanks\$1\$\$special\$\$inlined\$map\$1\$2\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$bonusForDoubleLetter\$1\$\$special\$\$inlined\$map\$1\$2\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$score2\$1\$\$special\$\$inlined\$map\$1\$2\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOptKt\$\$special\$\$inlined\$collect\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOptKt\$\$special\$\$inlined\$collect\$2\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleOpt\$play\$histoOfLetters\$1\$\$special\$\$inlined\$fold\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleBase\$play\$buildHistoOnScore\$1\$\$special\$\$inlined\$filter\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/FlowPlaysScrabbleBase\$play\$histoOfLetters\$1\$\$special\$\$inlined\$fold\$1\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/scrabble/SaneFlowPlaysScrabble\$play\$buildHistoOnScore\$1\$\$special\$\$inlined\$filter\$1\$1.class",
        // Primes
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/misc/Numbers\$\$special\$\$inlined\$filter\$1\$2\$1.class",
        "$buildDir/classes/kotlin/jmh/benchmarks/flow/misc/Numbers\$\$special\$\$inlined\$filter\$1\$1.class"
    )
}

// The JMH bytecode generator may only run once the stale class files have been removed.
tasks.named("jmhRunBytecodeGenerator").configure {
    dependsOn(removeRedundantFiles)
}

// It is better to use the following to run benchmarks, otherwise you may get unexpected errors:
// ./gradlew --no-daemon cleanJmhJar jmh -Pjmh="MyBenchmark"
Expand Down
Expand Up @@ -12,25 +12,23 @@ import java.util.concurrent.*
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
open class CombineBenchmark {
open class CombineFlowsBenchmarVolatilek {

@Benchmark
fun measure10() = measure(10)
@Param("10", "100", "1000")
private var size = 10

@Benchmark
fun measure100() = measure(100)
fun combine() = runBlocking {
combine((1 until size).map { flowOf(it) }) { a -> a}.collect()
}

@Benchmark
fun measure1000() = measure(1000)

fun measure(size: Int) = runBlocking {
val flowList = (1..size).map { flowOf(it) }
val listFlow = combine(flowList) { it.toList() }

listFlow.collect {
}
fun combineTransform() = runBlocking {
val list = (1 until size).map { flowOf(it) }.toList()
combineTransform((1 until size).map { flowOf(it) }) { emit(it) }.collect()
}
}

@@ -0,0 +1,47 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.internal.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*

/**
 * Throughput benchmark that combines a range-backed flow with itself
 * through each of the four two-flow entry points: `combine`, `combineTransform`
 * and their list-accepting counterparts.
 */
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
open class CombineTwoFlowsBenchmark {

    @Param("100", "100000", "1000000")
    private var size = 100000

    // Source used on both sides of every combination: the longs 1 until `size`.
    private fun source(): Flow<Long> = (1 until size.toLong()).asFlow()

    @Benchmark
    fun combinePlain() = runBlocking {
        val numbers = source()
        numbers.combine(numbers) { left, right -> left + right }.collect()
    }

    @Benchmark
    fun combineTransform() = runBlocking {
        val numbers = source()
        numbers.combineTransform(numbers) { left, right -> emit(left + right) }.collect()
    }

    @Benchmark
    fun combineVararg() = runBlocking {
        val numbers = source()
        combine(listOf(numbers, numbers)) { values -> values[0] + values[1] }.collect()
    }

    @Benchmark
    fun combineTransformVararg() = runBlocking {
        val numbers = source()
        combineTransform(listOf(numbers, numbers)) { values -> emit(values[0] + values[1]) }.collect()
    }
}
Expand Up @@ -137,14 +137,6 @@ internal abstract class AbstractSendChannel<E>(
return sendSuspend(element)
}

/**
 * Sends [element], yielding after a successful fast-path offer.
 * When the offer does not succeed immediately, falls back to the suspending send.
 */
internal suspend fun sendFair(element: E) {
    // Slow path: the element could not be offered right away
    if (offerInternal(element) !== OFFER_SUCCESS) return sendSuspend(element)
    yield() // Works only on fast path to properly work in sequential use-cases
}

public final override fun offer(element: E): Boolean {
val result = offerInternal(element)
return when {
Expand Down
Expand Up @@ -34,9 +34,4 @@ internal open class ChannelCoroutine<E>(
_channel.cancel(exception) // cancel the channel
cancelCoroutine(exception) // cancel the job
}

/**
 * Forwards [element] to the fair send of the underlying channel,
 * which is cast to [AbstractSendChannel] to reach it.
 */
@Suppress("UNCHECKED_CAST")
suspend fun sendFair(element: E) {
    val sendChannel = _channel as AbstractSendChannel<E>
    sendChannel.sendFair(element)
}
}
127 changes: 42 additions & 85 deletions kotlinx-coroutines-core/common/src/flow/internal/Combine.kt
Expand Up @@ -9,107 +9,51 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.selects.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*

// Workaround for JS BE bug: hands out the NULL symbol through a function call
internal fun getNull(): Symbol {
    return NULL
}

/**
 * Collects [first] and [second] concurrently and invokes [transform] with the most
 * recent value of each flow every time either of them delivers a new value,
 * once both flows have delivered at least one.
 *
 * Both flows are turned into channels with `asFairChannel`; a `select` loop then
 * receives from whichever channel is still open until both are closed.
 */
internal suspend fun <T1, T2, R> FlowCollector<R>.combineTransformInternal(
first: Flow<T1>, second: Flow<T2>,
transform: suspend FlowCollector<R>.(a: T1, b: T2) -> Unit
) {
coroutineScope {
val firstChannel = asFairChannel(first)
val secondChannel = asFairChannel(second)
// `null` here means "nothing received from this side yet"
var firstValue: Any? = null
var secondValue: Any? = null
var firstIsClosed = false
var secondIsClosed = false
// Keep selecting until both channels have reported closure
while (!firstIsClosed || !secondIsClosed) {
select<Unit> {
// The onReceive helper registers no clause for an already closed channel
onReceive(firstIsClosed, firstChannel, { firstIsClosed = true }) { value ->
firstValue = value
// Only call transform once the other side has produced a value too
if (secondValue !== null) {
transform(getNull().unbox(firstValue), getNull().unbox(secondValue) as T2)
}
}

onReceive(secondIsClosed, secondChannel, { secondIsClosed = true }) { value ->
secondValue = value
// Mirror of the clause above for the second flow
if (firstValue !== null) {
transform(getNull().unbox(firstValue) as T1, getNull().unbox(secondValue) as T2)
}
}
}
}
}
}

@PublishedApi
internal suspend fun <R, T> FlowCollector<R>.combineInternal(
flows: Array<out Flow<T>>,
arrayFactory: () -> Array<T?>,
transform: suspend FlowCollector<R>.(Array<T>) -> Unit
): Unit = coroutineScope {
): Unit = flowScope { // flow scope so any cancellation within the source flow will cancel the whole scope
val size = flows.size
val channels = Array(size) { asFairChannel(flows[it]) }
val latestValues = arrayOfNulls<Any?>(size)
val latestValues = Array<Any?>(size) { NULL }
val isClosed = Array(size) { false }
var nonClosed = size
var remainingNulls = size
// See flow.combine(other) for explanation of the logic
// Reuse receive blocks to avoid allocations on each iteration
val onReceiveBlocks = Array<suspend (Any?) -> Unit>(size) { i ->
{ value ->
if (value === null) {
isClosed[i] = true;
--nonClosed
}
else {
if (latestValues[i] == null) --remainingNulls
latestValues[i] = value
if (remainingNulls == 0) {
val arguments = arrayFactory()
for (index in 0 until size) {
arguments[index] = NULL.unbox(latestValues[index])
val resultChannel = Channel<Array<T>>(Channel.CONFLATED)
val nonClosed = LocalAtomicInt(size)
val remainingAbsentValues = LocalAtomicInt(size)
for (i in 0 until size) {
// Coroutine per flow that keeps track of its value and sends result to downstream
launch {
try {
flows[i].collect { value ->
val previous = latestValues[i]
latestValues[i] = value
if (previous === NULL) remainingAbsentValues.decrementAndGet()
if (remainingAbsentValues.value == 0) {
val results = arrayFactory()
for (index in 0 until size) {
results[index] = NULL.unbox(latestValues[index])
}
// NB: here actually "stale" array can overwrite a fresh one and break linearizability
resultChannel.send(results as Array<T>)
}
transform(arguments as Array<T>)
yield() // Emulate fairness for backward compatibility
}
} finally {
isClosed[i] = true
// Close the channel when there are no more flows
if (nonClosed.decrementAndGet() == 0) {
resultChannel.close()
}
}
}
}

while (nonClosed != 0) {
select<Unit> {
for (i in 0 until size) {
if (isClosed[i]) continue
channels[i].onReceiveOrNull(onReceiveBlocks[i])
}
}
}
}

/**
 * Registers a receive clause for [channel] unless [isClosed] is already set.
 * A received `null` is reported through [onClosed]; any other value goes to [onReceive].
 */
private inline fun SelectBuilder<Unit>.onReceive(
    isClosed: Boolean,
    channel: ReceiveChannel<Any>,
    crossinline onClosed: () -> Unit,
    noinline onReceive: suspend (value: Any) -> Unit
) {
    if (!isClosed) {
        @Suppress("DEPRECATION")
        channel.onReceiveOrNull { received ->
            // TODO onReceiveOrClosed when boxing issues are fixed
            if (received !== null) {
                onReceive(received)
            } else {
                onClosed()
            }
        }
    }
}

// Channel has any type due to onReceiveOrNull. This will be fixed after receiveOrClosed
private fun CoroutineScope.asFairChannel(flow: Flow<*>): ReceiveChannel<Any> = produce {
val channel = channel as ChannelCoroutine<Any>
flow.collect { value ->
return@collect channel.sendFair(value ?: NULL)
resultChannel.consumeEach {
transform(it)
}
}

Expand All @@ -131,12 +75,25 @@ internal fun <T1, T2, R> zipImpl(flow: Flow<T1>, flow2: Flow<T2>, transform: sus
val collectJob = Job()
val scopeJob = currentCoroutineContext()[Job]!!
(second as SendChannel<*>).invokeOnClose {
// Optimization to avoid AFE allocation when the other flow is done
if (!collectJob.isActive) collectJob.cancel(AbortFlowException(this@unsafeFlow))
}

val newContext = coroutineContext + scopeJob
val cnt = threadContextElements(newContext)
try {
/*
* Non-trivial undispatched (because we are in the right context and there is no structured concurrency)
* hierarchy:
* -Outer coroutineScope that owns the whole zip process
* - First flow is collected by the child of coroutineScope, collectJob.
* So it can be safely cancelled as soon as the second flow is done
* - **But** the downstream MUST NOT be cancelled when the second flow is done,
* so we emit to downstream from coroutineScope job.
* Typically, such hierarchy requires coroutine for collector that communicates
* with coroutines scope via a channel, but it's way too expensive, so
* we are using this trick instead.
*/
withContextUndispatched( coroutineContext + collectJob) {
flow.collect { value ->
val otherValue = second.receiveOrNull() ?: return@collect
Expand Down
13 changes: 6 additions & 7 deletions kotlinx-coroutines-core/common/src/flow/operators/Zip.kt
Expand Up @@ -31,9 +31,7 @@ import kotlinx.coroutines.flow.internal.unsafeFlow as flow
*/
@JvmName("flowCombine")
public fun <T1, T2, R> Flow<T1>.combine(flow: Flow<T2>, transform: suspend (a: T1, b: T2) -> R): Flow<R> = flow {
combineTransformInternal(this@combine, flow) { a, b ->
emit(transform(a, b))
}
combineInternal(arrayOf(this@combine, flow), { arrayOfNulls(2) }, { emit(transform(it[0] as T1, it[1] as T2)) })
}

/**
Expand Down Expand Up @@ -75,10 +73,11 @@ public fun <T1, T2, R> combine(flow: Flow<T1>, flow2: Flow<T2>, transform: suspe
public fun <T1, T2, R> Flow<T1>.combineTransform(
flow: Flow<T2>,
@BuilderInference transform: suspend FlowCollector<R>.(a: T1, b: T2) -> Unit
): Flow<R> = safeFlow {
combineTransformInternal(this@combineTransform, flow) { a, b ->
transform(a, b)
}
): Flow<R> = combineTransform(this, flow) { args: Array<*> ->
transform(
args[0] as T1,
args[1] as T2
)
}

/**
Expand Down
32 changes: 32 additions & 0 deletions kotlinx-coroutines-core/common/src/internal/LocalAtomics.common.kt
@@ -0,0 +1,32 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal

/*
* These are atomics that are used as local variables
* where atomicfu doesn't support its transformations.
*
* Have `Local` prefix to avoid AFU clashes during star-imports
*/

// In fact, used as @Volatile
/**
 * Expected declaration of a holder for a single reference of type [T],
 * constructed with an initial value.
 *
 * NOTE(review): only the signatures are declared here; visibility and atomicity
 * guarantees come from the platform `actual` implementations — confirm there.
 */
internal expect class LocalAtomicRef<T>(value: T) {
// Reads the stored reference (presumably the most recently set one — see actuals).
fun get(): T
// Stores [value] in place of the previous reference.
fun set(value: T)
}

/**
 * Property-style access to the reference held by a [LocalAtomicRef]:
 * reading delegates to [LocalAtomicRef.get], writing to [LocalAtomicRef.set].
 *
 * Declared for any element type [T]. [LocalAtomicRef] is invariant in its type
 * parameter, so an extension on `LocalAtomicRef<Any?>` alone would not be
 * applicable to e.g. a `LocalAtomicRef<String>`.
 */
internal inline var <T> LocalAtomicRef<T>.value: T
    get() = get()
    set(value) = set(value)

/**
 * Expected declaration of an integer counter constructed with an initial value.
 *
 * NOTE(review): only the signatures are declared here; the atomicity guarantees
 * come from the platform `actual` implementations — confirm there.
 */
internal expect class LocalAtomicInt(value: Int) {
// Reads the current counter value.
fun get(): Int
// Overwrites the counter with [value].
fun set(value: Int)
// Presumably decrements the counter and returns the new value, as the name says — confirm against actuals.
fun decrementAndGet(): Int
}

/**
 * Property-style access to a [LocalAtomicInt]:
 * reading delegates to [LocalAtomicInt.get], writing to [LocalAtomicInt.set].
 */
internal inline var LocalAtomicInt.value: Int
    get() = this.get()
    set(newValue) {
        this.set(newValue)
    }
Expand Up @@ -211,16 +211,16 @@ abstract class CombineTestBase : TestBase() {
hang { expect(3) }
}

val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach { expect(2) }
val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach { expectUnreached() }
assertFailsWith<CancellationException>(flow)
finish(4)
finish(2)
}

@Test
fun testCancellationExceptionDownstream() = runTest {
val f1 = flow {
emit(1)
expect(2)
expect(1)
hang { expect(5) }
}
val f2 = flow {
Expand All @@ -230,7 +230,7 @@ abstract class CombineTestBase : TestBase() {
}

val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach {
expect(1)
expect(2)
yield()
expect(4)
throw CancellationException("")
Expand Down

0 comments on commit bb167b3

Please sign in to comment.