From e710048362179c93fdb2dfb2000eaffe5eadfbf3 Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Tue, 3 Nov 2020 23:04:19 +0300 Subject: [PATCH] SharedFlow: Fix scenario with concurrent emitters and cancellation of subscriber (#2359) * Added a specific test for a problematic scenario. * Added stress test with concurrent emitters and subscribers that come and go. Fixes #2356 --- .../common/src/flow/SharedFlow.kt | 6 ++ .../flow/sharing/SharedFlowScenarioTest.kt | 42 +++++++++ .../jvm/test/flow/SharedFlowStressTest.kt | 87 +++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt diff --git a/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt b/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt index 427041a7bb..feb2749595 100644 --- a/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt @@ -498,6 +498,12 @@ private class SharedFlowImpl( } // Compute new buffer size -> how many values we now actually have after resume val newBufferSize1 = (newBufferEndIndex - head).toInt() + // Note: When nCollectors == 0 we resume ALL queued emitters and we might have resumed more than bufferCapacity, + // and newMinCollectorIndex might pointing the wrong place because of that. The easiest way to fix it is by + // forcing newMinCollectorIndex = newBufferEndIndex. We do not needed to update newBufferSize1 (which could be + // too big), because the only use of newBufferSize1 in the below code is in the minOf(replay, newBufferSize1) + // expression, which coerces values that are too big anyway. + if (nCollectors == 0) newMinCollectorIndex = newBufferEndIndex // Compute new replay size -> limit to replay the number of items we need, take into account that it can only grow var newReplayIndex = maxOf(replayIndex, newBufferEndIndex - minOf(replay, newBufferSize1)) // adjustment for synchronous case with cancelled emitter (NO_VALUE) diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowScenarioTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowScenarioTest.kt index c3eb2dac04..794553b482 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowScenarioTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowScenarioTest.kt @@ -201,6 +201,48 @@ class SharedFlowScenarioTest : TestBase() { emitResumes(e3); expectReplayOf(3) } + @Test + fun testSuspendedConcurrentEmitAndCancelSubscriberReplay1() = + testSharedFlow(MutableSharedFlow(1)) { + val a = subscribe("a"); + emitRightNow(0); expectReplayOf(0) + collect(a, 0) + emitRightNow(1); expectReplayOf(1) + val e2 = emitSuspends(2) // suspends until 1 is collected + val e3 = emitSuspends(3) // suspends until 1 is collected, too + cancel(a) // must resume emitters 2 & 3 + emitResumes(e2) + emitResumes(e3) + expectReplayOf(3) // but replay size is 1 so only 3 should be kept + // Note: originally, SharedFlow was in a broken state here with 3 elements in the buffer + val b = subscribe("b") + collect(b, 3) + emitRightNow(4); expectReplayOf(4) + collect(b, 4) + } + + @Test + fun testSuspendedConcurrentEmitAndCancelSubscriberReplay1ExtraBuffer1() = + testSharedFlow(MutableSharedFlow( replay = 1, extraBufferCapacity = 1)) { + val a = subscribe("a"); + emitRightNow(0); expectReplayOf(0) + collect(a, 0) + emitRightNow(1); expectReplayOf(1) + emitRightNow(2); expectReplayOf(2) + val e3 = emitSuspends(3) // suspends until 1 is collected + val e4 = emitSuspends(4) // suspends until 1 is collected, too + val e5 = emitSuspends(5) // suspends until 1 is collected, too + cancel(a) // must resume emitters 3, 4, 5 + emitResumes(e3) + emitResumes(e4) + emitResumes(e5) + expectReplayOf(5) + val b = subscribe("b") + collect(b, 5) + emitRightNow(6); expectReplayOf(6) + collect(b, 6) + } + private fun testSharedFlow( sharedFlow: MutableSharedFlow, scenario: suspend ScenarioDsl.() -> Unit diff --git a/kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt b/kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt new file mode 100644 index 0000000000..349b7c8121 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt @@ -0,0 +1,87 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.flow + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import org.junit.* +import org.junit.Test +import kotlin.collections.ArrayList +import kotlin.test.* +import kotlin.time.* + +@ExperimentalTime +class SharedFlowStressTest : TestBase() { + private val nProducers = 5 + private val nConsumers = 3 + private val nSeconds = 3 * stressTestMultiplier + + private lateinit var sf: MutableSharedFlow + private lateinit var view: SharedFlow + + @get:Rule + val producerDispatcher = ExecutorRule(nProducers) + @get:Rule + val consumerDispatcher = ExecutorRule(nConsumers) + + private val totalProduced = atomic(0L) + private val totalConsumed = atomic(0L) + + @Test + fun testStressReplay1() = + testStress(1, 0) + + @Test + fun testStressReplay1ExtraBuffer1() = + testStress(1, 1) + + @Test + fun testStressReplay2ExtraBuffer1() = + testStress(2, 1) + + private fun testStress(replay: Int, extraBufferCapacity: Int) = runTest { + sf = MutableSharedFlow(replay, extraBufferCapacity) + view = sf.asSharedFlow() + val jobs = ArrayList() + jobs += List(nProducers) { producerIndex -> + launch(producerDispatcher) { + var cur = producerIndex.toLong() + while (isActive) { + sf.emit(cur) + totalProduced.incrementAndGet() + cur += nProducers + } + } + } + jobs += List(nConsumers) { consumerIndex -> + launch(consumerDispatcher) { + while (isActive) { + view + .dropWhile { it % nConsumers != consumerIndex.toLong() } + .take(1) + .collect { + check(it % nConsumers == consumerIndex.toLong()) + totalConsumed.incrementAndGet() + } + } + } + } + var lastProduced = 0L + var lastConsumed = 0L + for (sec in 1..nSeconds) { + delay(1.seconds) + val produced = totalProduced.value + val consumed = totalConsumed.value + println("$sec sec: produced = $produced; consumed = $consumed") + assertNotEquals(lastProduced, produced) + assertNotEquals(lastConsumed, consumed) + lastProduced = produced + lastConsumed = consumed + } + jobs.forEach { it.cancel() } + jobs.forEach { it.join() } + println("total: produced = ${totalProduced.value}; consumed = ${totalConsumed.value}") + } +} \ No newline at end of file