diff --git a/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt b/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt index d79e203464..e59e9c853f 100644 --- a/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/SharedFlow.kt @@ -198,6 +198,8 @@ public interface MutableSharedFlow : SharedFlow, FlowCollector { * } * .launchIn(scope) // launch it * ``` + * + * Implementation note: the resulting flow **does not** conflate subscription count. */ public val subscriptionCount: StateFlow @@ -253,7 +255,7 @@ public fun MutableSharedFlow( // ------------------------------------ Implementation ------------------------------------ -private class SharedFlowSlot : AbstractSharedFlowSlot>() { +internal class SharedFlowSlot : AbstractSharedFlowSlot>() { @JvmField var index = -1L // current "to-be-emitted" index, -1 means the slot is free now @@ -275,7 +277,7 @@ private class SharedFlowSlot : AbstractSharedFlowSlot>() { } } -private class SharedFlowImpl( +internal class SharedFlowImpl( private val replay: Int, private val bufferCapacity: Int, private val onBufferOverflow: BufferOverflow @@ -334,6 +336,13 @@ private class SharedFlowImpl( result } + /* + * A tweak for SubscriptionCountStateFlow to get the latest value. + */ + @Suppress("UNCHECKED_CAST") + val lastReplayedLocked: T + get() = buffer!!.getBufferAt(replayIndex + replaySize - 1) as T + @Suppress("UNCHECKED_CAST") override suspend fun collect(collector: FlowCollector) { val slot = allocateSlot() diff --git a/kotlinx-coroutines-core/common/src/flow/StateFlow.kt b/kotlinx-coroutines-core/common/src/flow/StateFlow.kt index 9e82e78771..53770dc91e 100644 --- a/kotlinx-coroutines-core/common/src/flow/StateFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/StateFlow.kt @@ -415,10 +415,6 @@ private class StateFlowImpl( fuseStateFlow(context, capacity, onBufferOverflow) } -internal fun MutableStateFlow.increment(delta: Int) { - update { it + delta } -} - internal fun StateFlow.fuseStateFlow( context: CoroutineContext, capacity: Int, diff --git a/kotlinx-coroutines-core/common/src/flow/internal/AbstractSharedFlow.kt b/kotlinx-coroutines-core/common/src/flow/internal/AbstractSharedFlow.kt index 7114cc08d3..48bacd5d67 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/AbstractSharedFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/AbstractSharedFlow.kt @@ -4,6 +4,7 @@ package kotlinx.coroutines.flow.internal +import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* import kotlinx.coroutines.internal.* import kotlin.coroutines.* @@ -26,12 +27,12 @@ internal abstract class AbstractSharedFlow> : Sync protected var nCollectors = 0 // number of allocated (!free) slots private set private var nextIndex = 0 // oracle for the next free slot index - private var _subscriptionCount: MutableStateFlow? = null // init on first need + private var _subscriptionCount: SubscriptionCountStateFlow? = null // init on first need val subscriptionCount: StateFlow get() = synchronized(this) { // allocate under lock in sync with nCollectors variable - _subscriptionCount ?: MutableStateFlow(nCollectors).also { + _subscriptionCount ?: SubscriptionCountStateFlow(nCollectors).also { _subscriptionCount = it } } @@ -43,7 +44,7 @@ internal abstract class AbstractSharedFlow> : Sync @Suppress("UNCHECKED_CAST") protected fun allocateSlot(): S { // Actually create slot under lock - var subscriptionCount: MutableStateFlow? = null + var subscriptionCount: SubscriptionCountStateFlow? = null val slot = synchronized(this) { val slots = when (val curSlots = slots) { null -> createSlotArray(2).also { slots = it } @@ -74,7 +75,7 @@ internal abstract class AbstractSharedFlow> : Sync @Suppress("UNCHECKED_CAST") protected fun freeSlot(slot: S) { // Release slot under lock - var subscriptionCount: MutableStateFlow? = null + var subscriptionCount: SubscriptionCountStateFlow? = null val resumes = synchronized(this) { nCollectors-- subscriptionCount = _subscriptionCount // retrieve under lock if initialized @@ -83,10 +84,10 @@ internal abstract class AbstractSharedFlow> : Sync (slot as AbstractSharedFlowSlot).freeLocked(this) } /* - Resume suspended coroutines. - This can happens when the subscriber that was freed was a slow one and was holding up buffer. - When this subscriber was freed, previously queued emitted can now wake up and are resumed here. - */ + * Resume suspended coroutines. + * This can happen when the subscriber that was freed was a slow one and was holding up buffer. + * When this subscriber was freed, previously queued emitted can now wake up and are resumed here. + */ for (cont in resumes) cont?.resume(Unit) // decrement subscription count subscriptionCount?.increment(-1) @@ -99,3 +100,43 @@ internal abstract class AbstractSharedFlow> : Sync } } } + +/** + * [StateFlow] that represents the number of subscriptions. + * + * It is exposed as a regular [StateFlow] in our public API, but it is implemented as [SharedFlow] undercover to + * avoid conflations of consecutive updates because the subscription count is very sensitive to it. + * + * The importance of non-conflating can be demonstrated with the following example: + * ``` + * val shared = flowOf(239).stateIn(this, SharingStarted.Lazily, 42) // stateIn for the sake of the initial value + * println(shared.first()) + * yield() + * println(shared.first()) + * ``` + * If the flow is shared within the same dispatcher (e.g. Main) or with a slow/throttled one, + * the `SharingStarted.Lazily` will never be able to start the source: `first` sees the initial value and immediately + * unsubscribes, leaving the asynchronous `SharingStarted` with conflated zero. + * + * To avoid that (especially in a more complex scenarios), we do not conflate subscription updates. + */ +private class SubscriptionCountStateFlow(initialValue: Int) : StateFlow { + private val sharedFlow = SharedFlowImpl(1, Int.MAX_VALUE, BufferOverflow.DROP_OLDEST) + .also { it.tryEmit(initialValue) } + + override val replayCache: List + get() = sharedFlow.replayCache + + override val value: Int + get() = synchronized(sharedFlow) { + sharedFlow.lastReplayedLocked + } + + fun increment(delta: Int) = synchronized(sharedFlow) { + sharedFlow.tryEmit(sharedFlow.lastReplayedLocked + delta) + } + + override suspend fun collect(collector: FlowCollector) { + sharedFlow.collect(collector) + } +} diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Share.kt b/kotlinx-coroutines-core/common/src/flow/operators/Share.kt index 4fa74d8e50..2b690e3c04 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Share.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Share.kt @@ -197,8 +197,16 @@ private fun CoroutineScope.launchSharing( shared: MutableSharedFlow, started: SharingStarted, initialValue: T -): Job = - launch(context) { // the single coroutine to rule the sharing +): Job { + /* + * Conditional start: in the case when sharing and subscribing happens in the same dispatcher, we want to + * have the following invariants preserved: + * * Delayed sharing strategies have a chance to immediately observe consecutive subscriptions. + * E.g. in the cases like `flow.shareIn(...); flow.take(1)` we want sharing strategy to see the initial subscription + * * Eager sharing does not start immediately, so the subscribers have actual chance to subscribe _prior_ to sharing. + */ + val start = if (started == SharingStarted.Eagerly) CoroutineStart.DEFAULT else CoroutineStart.UNDISPATCHED + return launch(context, start = start) { // the single coroutine to rule the sharing // Optimize common built-in started strategies when { started === SharingStarted.Eagerly -> { @@ -230,6 +238,7 @@ private fun CoroutineScope.launchSharing( } } } +} // -------------------------------- stateIn -------------------------------- diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/ShareInConflationTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/ShareInConflationTest.kt index 0528e97e7d..c19d52367b 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/ShareInConflationTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/ShareInConflationTest.kt @@ -21,7 +21,7 @@ class ShareInConflationTest : TestBase() { op: suspend Flow.(CoroutineScope) -> Flow ) = runTest { expect(1) - // emit all and conflate, then should collect bufferCapacity latest ones + // emit all and conflate, then should collect bufferCapacity the latest ones val done = Job() flow { repeat(n) { i -> @@ -159,4 +159,4 @@ class ShareInConflationTest : TestBase() { checkConflation(1, BufferOverflow.DROP_LATEST) { buffer(23).buffer(onBufferOverflow = BufferOverflow.DROP_LATEST).shareIn(it, SharingStarted.Eagerly, 0) } -} \ No newline at end of file +} diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt index db69e2bc06..cf83a50b0f 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt @@ -210,4 +210,30 @@ class ShareInTest : TestBase() { stop() } } + + @Test + fun testShouldStart() = runTest { + val flow = flow { + expect(2) + emit(1) + expect(3) + }.shareIn(this, SharingStarted.Lazily) + + expect(1) + flow.onSubscription { throw CancellationException("") } + .catch { e -> assertTrue { e is CancellationException } } + .collect() + yield() + finish(4) + } + + @Test + fun testShouldStartScalar() = runTest { + val j = Job() + val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42) + assertEquals(42, shared.first()) + yield() + assertEquals(239, shared.first()) + j.cancel() + } } diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt index 6e18b38f55..98e04f00e8 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt @@ -798,4 +798,24 @@ class SharedFlowTest : TestBase() { job.join() finish(5) } + + @Test + fun testSubscriptionCount() = runTest { + val flow = MutableSharedFlow() + fun startSubscriber() = launch(start = CoroutineStart.UNDISPATCHED) { flow.collect() } + + assertEquals(0, flow.subscriptionCount.first()) + + val j1 = startSubscriber() + assertEquals(1, flow.subscriptionCount.first()) + + val j2 = startSubscriber() + assertEquals(2, flow.subscriptionCount.first()) + + j1.cancelAndJoin() + assertEquals(1, flow.subscriptionCount.first()) + + j2.cancelAndJoin() + assertEquals(0, flow.subscriptionCount.first()) + } } diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/SharingStartedWhileSubscribedTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/SharingStartedWhileSubscribedTest.kt index 516bb2e291..b3a3400389 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/SharingStartedWhileSubscribedTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/SharingStartedWhileSubscribedTest.kt @@ -40,5 +40,38 @@ class SharingStartedWhileSubscribedTest : TestBase() { assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = 7000), SharingStarted.WhileSubscribed(replayExpiration = 7.seconds)) assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = Long.MAX_VALUE), SharingStarted.WhileSubscribed(replayExpiration = Duration.INFINITE)) } -} + @Test + fun testShouldRestart() = runTest { + var started = 0 + val flow = flow { + expect(1 + ++started) + emit(1) + hang { } + }.shareIn(this, SharingStarted.WhileSubscribed(100 /* ms */)) + + expect(1) + flow.first() + delay(200) + flow.first() + finish(4) + coroutineContext.job.cancelChildren() + } + + @Test + fun testImmediateUnsubscribe() = runTest { + val flow = flow { + expect(2) + emit(1) + hang { finish(4) } + }.shareIn(this, SharingStarted.WhileSubscribed(400, 0 /* ms */), 1) + + expect(1) + repeat(5) { + flow.first() + delay(100) + } + expect(3) + coroutineContext.job.cancelChildren() + } +} diff --git a/kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt b/kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt index 7d346bdc33..25c0c98314 100644 --- a/kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt @@ -189,5 +189,9 @@ class SharingStressTest : TestBase() { var count = 0L } - private fun log(msg: String) = println("${testStarted.elapsedNow().toLongMilliseconds()} ms: $msg") -} \ No newline at end of file + private fun log(msg: String) = println("${testStarted.elapsedNow().inWholeMilliseconds} ms: $msg") + + private fun MutableStateFlow.increment(delta: Int) { + update { it + delta } + } +}