Skip to content

Commit

Permalink
Non-conflating subscription count in SharedFlow and StateFlow (Kotlin…
Browse files Browse the repository at this point in the history
…#2872)

* Non-conflating subscription count in SharedFlow and StateFlow

Sharing strategies are too sensitive to conflation around extrema and may miss the necessity to start or not to stop the sharing. For more particular examples see Kotlin#2863 and Kotlin#2488

Fixes Kotlin#2488
Fixes Kotlin#2863
Fixes Kotlin#2871
  • Loading branch information
qwwdfsad authored and pablobaxter committed Sep 14, 2022
1 parent 6f00cdf commit 905b0cb
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 21 deletions.
13 changes: 11 additions & 2 deletions kotlinx-coroutines-core/common/src/flow/SharedFlow.kt
Expand Up @@ -198,6 +198,8 @@ public interface MutableSharedFlow<T> : SharedFlow<T>, FlowCollector<T> {
* }
* .launchIn(scope) // launch it
* ```
*
* Implementation note: the resulting flow **does not** conflate subscription count.
*/
public val subscriptionCount: StateFlow<Int>

Expand Down Expand Up @@ -253,7 +255,7 @@ public fun <T> MutableSharedFlow(

// ------------------------------------ Implementation ------------------------------------

private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
internal class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
@JvmField
var index = -1L // current "to-be-emitted" index, -1 means the slot is free now

Expand All @@ -275,7 +277,7 @@ private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
}
}

private class SharedFlowImpl<T>(
internal open class SharedFlowImpl<T>(
private val replay: Int,
private val bufferCapacity: Int,
private val onBufferOverflow: BufferOverflow
Expand Down Expand Up @@ -334,6 +336,13 @@ private class SharedFlowImpl<T>(
result
}

/*
* A tweak for SubscriptionCountStateFlow to get the latest value.
*/
@Suppress("UNCHECKED_CAST")
protected val lastReplayedLocked: T
get() = buffer!!.getBufferAt(replayIndex + replaySize - 1) as T

@Suppress("UNCHECKED_CAST")
override suspend fun collect(collector: FlowCollector<T>) {
val slot = allocateSlot()
Expand Down
4 changes: 0 additions & 4 deletions kotlinx-coroutines-core/common/src/flow/StateFlow.kt
Expand Up @@ -415,10 +415,6 @@ private class StateFlowImpl<T>(
fuseStateFlow(context, capacity, onBufferOverflow)
}

internal fun MutableStateFlow<Int>.increment(delta: Int) {
update { it + delta }
}

internal fun <T> StateFlow<T>.fuseStateFlow(
context: CoroutineContext,
capacity: Int,
Expand Down
Expand Up @@ -4,6 +4,7 @@

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
Expand All @@ -26,12 +27,12 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
protected var nCollectors = 0 // number of allocated (!free) slots
private set
private var nextIndex = 0 // oracle for the next free slot index
private var _subscriptionCount: MutableStateFlow<Int>? = null // init on first need
private var _subscriptionCount: SubscriptionCountStateFlow? = null // init on first need

val subscriptionCount: StateFlow<Int>
get() = synchronized(this) {
// allocate under lock in sync with nCollectors variable
_subscriptionCount ?: MutableStateFlow(nCollectors).also {
_subscriptionCount ?: SubscriptionCountStateFlow(nCollectors).also {
_subscriptionCount = it
}
}
Expand All @@ -43,7 +44,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
@Suppress("UNCHECKED_CAST")
protected fun allocateSlot(): S {
// Actually create slot under lock
var subscriptionCount: MutableStateFlow<Int>? = null
var subscriptionCount: SubscriptionCountStateFlow? = null
val slot = synchronized(this) {
val slots = when (val curSlots = slots) {
null -> createSlotArray(2).also { slots = it }
Expand Down Expand Up @@ -74,7 +75,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
@Suppress("UNCHECKED_CAST")
protected fun freeSlot(slot: S) {
// Release slot under lock
var subscriptionCount: MutableStateFlow<Int>? = null
var subscriptionCount: SubscriptionCountStateFlow? = null
val resumes = synchronized(this) {
nCollectors--
subscriptionCount = _subscriptionCount // retrieve under lock if initialized
Expand All @@ -83,10 +84,10 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
(slot as AbstractSharedFlowSlot<Any>).freeLocked(this)
}
/*
Resume suspended coroutines.
This can happens when the subscriber that was freed was a slow one and was holding up buffer.
When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
*/
* Resume suspended coroutines.
* This can happen when the subscriber that was freed was a slow one and was holding up buffer.
* When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
*/
for (cont in resumes) cont?.resume(Unit)
// decrement subscription count
subscriptionCount?.increment(-1)
Expand All @@ -99,3 +100,35 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
}
}
}

/**
* [StateFlow] that represents the number of subscriptions.
*
* It is exposed as a regular [StateFlow] in our public API, but it is implemented as [SharedFlow] undercover to
* avoid conflations of consecutive updates because the subscription count is very sensitive to it.
*
* The importance of non-conflating can be demonstrated with the following example:
* ```
* val shared = flowOf(239).stateIn(this, SharingStarted.Lazily, 42) // stateIn for the sake of the initial value
* println(shared.first())
* yield()
* println(shared.first())
* ```
* If the flow is shared within the same dispatcher (e.g. Main) or with a slow/throttled one,
* the `SharingStarted.Lazily` will never be able to start the source: `first` sees the initial value and immediately
* unsubscribes, leaving the asynchronous `SharingStarted` with conflated zero.
*
* To avoid that (especially in a more complex scenarios), we do not conflate subscription updates.
*/
private class SubscriptionCountStateFlow(initialValue: Int) : StateFlow<Int>,
SharedFlowImpl<Int>(1, Int.MAX_VALUE, BufferOverflow.DROP_OLDEST)
{
init { tryEmit(initialValue) }

override val value: Int
get() = synchronized(this) { lastReplayedLocked }

fun increment(delta: Int) = synchronized(this) {
tryEmit(lastReplayedLocked + delta)
}
}
13 changes: 11 additions & 2 deletions kotlinx-coroutines-core/common/src/flow/operators/Share.kt
Expand Up @@ -197,8 +197,16 @@ private fun <T> CoroutineScope.launchSharing(
shared: MutableSharedFlow<T>,
started: SharingStarted,
initialValue: T
): Job =
launch(context) { // the single coroutine to rule the sharing
): Job {
/*
* Conditional start: in the case when sharing and subscribing happens in the same dispatcher, we want to
* have the following invariants preserved:
* * Delayed sharing strategies have a chance to immediately observe consecutive subscriptions.
* E.g. in the cases like `flow.shareIn(...); flow.take(1)` we want sharing strategy to see the initial subscription
* * Eager sharing does not start immediately, so the subscribers have actual chance to subscribe _prior_ to sharing.
*/
val start = if (started == SharingStarted.Eagerly) CoroutineStart.DEFAULT else CoroutineStart.UNDISPATCHED
return launch(context, start = start) { // the single coroutine to rule the sharing
// Optimize common built-in started strategies
when {
started === SharingStarted.Eagerly -> {
Expand Down Expand Up @@ -230,6 +238,7 @@ private fun <T> CoroutineScope.launchSharing(
}
}
}
}

// -------------------------------- stateIn --------------------------------

Expand Down
Expand Up @@ -21,7 +21,7 @@ class ShareInConflationTest : TestBase() {
op: suspend Flow<Int>.(CoroutineScope) -> Flow<Int>
) = runTest {
expect(1)
// emit all and conflate, then should collect bufferCapacity latest ones
// emit all and conflate, then should collect bufferCapacity the latest ones
val done = Job()
flow {
repeat(n) { i ->
Expand Down Expand Up @@ -159,4 +159,4 @@ class ShareInConflationTest : TestBase() {
checkConflation(1, BufferOverflow.DROP_LATEST) {
buffer(23).buffer(onBufferOverflow = BufferOverflow.DROP_LATEST).shareIn(it, SharingStarted.Eagerly, 0)
}
}
}
26 changes: 26 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt
Expand Up @@ -210,4 +210,30 @@ class ShareInTest : TestBase() {
stop()
}
}

@Test
fun testShouldStart() = runTest {
val flow = flow {
expect(2)
emit(1)
expect(3)
}.shareIn(this, SharingStarted.Lazily)

expect(1)
flow.onSubscription { throw CancellationException("") }
.catch { e -> assertTrue { e is CancellationException } }
.collect()
yield()
finish(4)
}

@Test
fun testShouldStartScalar() = runTest {
val j = Job()
val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42)
assertEquals(42, shared.first())
yield()
assertEquals(239, shared.first())
j.cancel()
}
}
20 changes: 20 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt
Expand Up @@ -798,4 +798,24 @@ class SharedFlowTest : TestBase() {
job.join()
finish(5)
}

@Test
fun testSubscriptionCount() = runTest {
val flow = MutableSharedFlow<Int>()
fun startSubscriber() = launch(start = CoroutineStart.UNDISPATCHED) { flow.collect() }

assertEquals(0, flow.subscriptionCount.first())

val j1 = startSubscriber()
assertEquals(1, flow.subscriptionCount.first())

val j2 = startSubscriber()
assertEquals(2, flow.subscriptionCount.first())

j1.cancelAndJoin()
assertEquals(1, flow.subscriptionCount.first())

j2.cancelAndJoin()
assertEquals(0, flow.subscriptionCount.first())
}
}
Expand Up @@ -40,5 +40,38 @@ class SharingStartedWhileSubscribedTest : TestBase() {
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = 7000), SharingStarted.WhileSubscribed(replayExpiration = 7.seconds))
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = Long.MAX_VALUE), SharingStarted.WhileSubscribed(replayExpiration = Duration.INFINITE))
}
}

@Test
fun testShouldRestart() = runTest {
var started = 0
val flow = flow {
expect(1 + ++started)
emit(1)
hang { }
}.shareIn(this, SharingStarted.WhileSubscribed(100 /* ms */))

expect(1)
flow.first()
delay(200)
flow.first()
finish(4)
coroutineContext.job.cancelChildren()
}

@Test
fun testImmediateUnsubscribe() = runTest {
val flow = flow {
expect(2)
emit(1)
hang { finish(4) }
}.shareIn(this, SharingStarted.WhileSubscribed(400, 0 /* ms */), 1)

expect(1)
repeat(5) {
flow.first()
delay(100)
}
expect(3)
coroutineContext.job.cancelChildren()
}
}
8 changes: 6 additions & 2 deletions kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt
Expand Up @@ -189,5 +189,9 @@ class SharingStressTest : TestBase() {
var count = 0L
}

private fun log(msg: String) = println("${testStarted.elapsedNow().toLongMilliseconds()} ms: $msg")
}
private fun log(msg: String) = println("${testStarted.elapsedNow().inWholeMilliseconds} ms: $msg")

private fun MutableStateFlow<Int>.increment(delta: Int) {
update { it + delta }
}
}

0 comments on commit 905b0cb

Please sign in to comment.