Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SharedFlow: Fix scenario with concurrent emitters and cancellation of subscriber #2359

Merged
merged 4 commits into from Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion kotlinx-coroutines-core/common/src/flow/SharedFlow.kt
Expand Up @@ -497,7 +497,13 @@ private class SharedFlowImpl<T>(
}
}
// Compute new buffer size -> how many values we now actually have after resume
val newBufferSize1 = (newBufferEndIndex - head).toInt()
var newBufferSize1 = (newBufferEndIndex - head).toInt()
// Note: When nCollectors == 0 we resume all queued emitters and we might have resumed more than bufferCapacity,
// if which case we need to coerce the resulting buffer size and adjust newMinCollectorIndex
if (nCollectors == 0 && newBufferSize1 > bufferCapacity) {
newMinCollectorIndex += newBufferSize1 - bufferCapacity // adjust minCollectorIndex, too, to skip items
newBufferSize1 = bufferCapacity
}
// Compute new replay size -> limit to replay the number of items we need, take into account that it can only grow
var newReplayIndex = maxOf(replayIndex, newBufferEndIndex - minOf(replay, newBufferSize1))
// adjustment for synchronous case with cancelled emitter (NO_VALUE)
Expand Down
Expand Up @@ -201,6 +201,26 @@ class SharedFlowScenarioTest : TestBase() {
emitResumes(e3); expectReplayOf(3)
}

@Test
fun testSuspendedConcurrentEmitAndCancelSubscriber() =
testSharedFlow<Int>(MutableSharedFlow(1)) {
val a = subscribe("a");
emitRightNow(0); expectReplayOf(0)
collect(a, 0)
emitRightNow(1); expectReplayOf(1)
val e2 = emitSuspends(2) // suspends until 1 is collected
val e3 = emitSuspends(3) // suspends until 1 is collected, too
cancel(a) // must resume emitters 2 & 3
emitResumes(e2)
emitResumes(e3)
expectReplayOf(3) // but replay size is 1 so only 3 should be kept
// Note: originally, SharedFlow was in a broken state here with 3 elements in the buffer
val b = subscribe("b")
collect(b, 3)
emitRightNow(4); expectReplayOf(4)
collect(b, 4)
}

private fun <T> testSharedFlow(
sharedFlow: MutableSharedFlow<T>,
scenario: suspend ScenarioDsl<T>.() -> Unit
Expand Down
77 changes: 77 additions & 0 deletions kotlinx-coroutines-core/jvm/test/flow/SharedFlowStressTest.kt
@@ -0,0 +1,77 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import org.junit.*
import org.junit.Test
import kotlin.collections.ArrayList
import kotlin.test.*
import kotlin.time.*

@ExperimentalTime
class SharedFlowStressTest : TestBase() {
private val nProducers = 5
private val nConsumers = 3
private val nSeconds = 3 * stressTestMultiplier

private val sf: MutableSharedFlow<Long> = MutableSharedFlow(1)
private val view: SharedFlow<Long> = sf.asSharedFlow()

@get:Rule
val producerDispatcher = ExecutorRule(nProducers)
@get:Rule
val consumerDispatcher = ExecutorRule(nConsumers)

private val totalProduced = atomic(0L)
private val totalConsumed = atomic(0L)

@Test
fun testStress() = runTest {
val jobs = ArrayList<Job>()
jobs += List(nProducers) { producerIndex ->
launch(producerDispatcher) {
var cur = producerIndex.toLong()
while (isActive) {
sf.emit(cur)
totalProduced.incrementAndGet()
cur += nProducers
}
}
}
jobs += List(nConsumers) { consumerIndex ->
launch(consumerDispatcher) {
while (isActive) {
view
.dropWhile { it % nConsumers != consumerIndex.toLong() }
.take(1)
.collect {
check(it % nConsumers == consumerIndex.toLong())
totalConsumed.incrementAndGet()
}
}
}
}
var lastProduced = 0L
var lastConsumed = 0L
for (sec in 1..nSeconds) {
delay(1.seconds)
val produced = totalProduced.value
val consumed = totalConsumed.value
println("$sec sec: produced = $produced; consumed = $consumed")
assertNotEquals(lastProduced, produced)
assertNotEquals(lastConsumed, consumed)
lastProduced = produced
lastConsumed = consumed
}
jobs.forEach { it.cancel() }
jobs.forEach { it.join() }
println("total: produced = ${totalProduced.value}; consumed = ${totalConsumed.value}")
}

private fun showStats(s: String) {
elizarov marked this conversation as resolved.
Show resolved Hide resolved
}
}