Skip to content

Commit

Permalink
SharedFlow: Fix scenario with concurrent emitters and cancellation of…
Browse files Browse the repository at this point in the history
… subscriber

* Added a specific test for a problematic scenario.
* Added stress test with concurrent emitters and subscribers that come and go.

Fixes #2356
  • Loading branch information
elizarov committed Nov 2, 2020
1 parent 4ea4078 commit 15c95b9
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 1 deletion.
8 changes: 7 additions & 1 deletion kotlinx-coroutines-core/common/src/flow/SharedFlow.kt
Expand Up @@ -497,7 +497,13 @@ private class SharedFlowImpl<T>(
}
}
// Compute new buffer size -> how many values we now actually have after resume
val newBufferSize1 = (newBufferEndIndex - head).toInt()
var newBufferSize1 = (newBufferEndIndex - head).toInt()
// Note: When nCollectors == 0 we resume all queued emitted and might have resumed more that max size of
// the buffer, so here is why we take coerce the resulting size to the buffer capacity
if (nCollectors == 0 && newBufferSize1 > bufferCapacity) {
newMinCollectorIndex += newBufferSize1 - bufferCapacity // adjust minCollectorIndex, too, to skip items
newBufferSize1 = bufferCapacity
}
// Compute new replay size -> limit to replay the number of items we need, take into account that it can only grow
var newReplayIndex = maxOf(replayIndex, newBufferEndIndex - minOf(replay, newBufferSize1))
// adjustment for synchronous case with cancelled emitter (NO_VALUE)
Expand Down
Expand Up @@ -201,6 +201,26 @@ class SharedFlowScenarioTest : TestBase() {
emitResumes(e3); expectReplayOf(3)
}

@Test
fun testSuspendedConcurrentEmitAndCancelSubscriber() =
testSharedFlow<Int>(MutableSharedFlow(1)) {
val a = subscribe("a");
emitRightNow(0); expectReplayOf(0)
collect(a, 0)
emitRightNow(1); expectReplayOf(1)
val e2 = emitSuspends(2) // suspends until 1 is collected
val e3 = emitSuspends(3) // suspends until 1 is collected, too
cancel(a) // must resume emitters 2 & 3
emitResumes(e2)
emitResumes(e3)
expectReplayOf(3) // but replay size is 1 so only 3 should be kept
// Note: originally, SharedFlow was in a broken state here with 3 elements in the buffer
val b = subscribe("b")
collect(b, 3)
emitRightNow(4); expectReplayOf(4)
collect(b, 4)
}

private fun <T> testSharedFlow(
sharedFlow: MutableSharedFlow<T>,
scenario: suspend ScenarioDsl<T>.() -> Unit
Expand Down
77 changes: 77 additions & 0 deletions kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt
@@ -0,0 +1,77 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import org.junit.*
import org.junit.Test
import kotlin.collections.ArrayList
import kotlin.test.*
import kotlin.time.*

@ExperimentalTime
class SharedFlowStressTest : TestBase() {
private val nProducers = 5
private val nConsumers = 3
private val nSeconds = 3 * stressTestMultiplier

private val sf: MutableSharedFlow<Long> = MutableSharedFlow(1)
private val view: SharedFlow<Long> = sf.asSharedFlow()

@get:Rule
val producerDispatcher = ExecutorRule(nProducers)
@get:Rule
val consumerDispatcher = ExecutorRule(nConsumers)

private val totalProduced = atomic(0L)
private val totalConsumed = atomic(0L)

@Test
fun testStress() = runTest {
val jobs = ArrayList<Job>()
jobs += List(nProducers) { producerIndex ->
launch(producerDispatcher) {
var cur = producerIndex.toLong()
while (isActive) {
sf.emit(cur)
totalProduced.incrementAndGet()
cur += nProducers
}
}
}
jobs += List(nConsumers) { consumerIndex ->
launch(consumerDispatcher) {
while (isActive) {
view
.dropWhile { it % nConsumers != consumerIndex.toLong() }
.take(1)
.collect {
check(it % nConsumers == consumerIndex.toLong())
totalConsumed.incrementAndGet()
}
}
}
}
var lastProduced = 0L
var lastConsumed = 0L
for (sec in 1..nSeconds) {
delay(1.seconds)
val produced = totalProduced.value
val consumed = totalConsumed.value
println("$sec sec: produced = $produced; consumed = $consumed")
assertNotEquals(lastProduced, produced)
assertNotEquals(lastConsumed, consumed)
lastProduced = produced
lastConsumed = consumed
}
jobs.forEach { it.cancel() }
jobs.forEach { it.join() }
println("total: produced = ${totalProduced.value}; consumed = ${totalConsumed.value}")
}

private fun showStats(s: String) {
}
}

0 comments on commit 15c95b9

Please sign in to comment.