Skip to content

Commit

Permalink
Do not propagate cancellation to the upstream in Flow flat* operators (
Browse files Browse the repository at this point in the history
…#2964)

* Do not propagate cancellation to the upstream in Flow flat* operators

Fixes #2964
  • Loading branch information
qwwdfsad committed Oct 20, 2021
1 parent 85b17ce commit 80af499
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 46 deletions.
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Expand Up @@ -133,7 +133,7 @@ internal fun <E> CoroutineScope.produce(
return coroutine
}

internal open class ProducerCoroutine<E>(
private class ProducerCoroutine<E>(
parentContext: CoroutineContext, channel: Channel<E>
) : ChannelCoroutine<E>(parentContext, channel, true, active = true), ProducerScope<E> {
override val isActive: Boolean
Expand Down
Expand Up @@ -51,33 +51,11 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
flowScope { block(this@flow) }
}

internal fun <T> CoroutineScope.flowProduce(
context: CoroutineContext,
capacity: Int = 0,
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
): ReceiveChannel<T> {
val channel = Channel<T>(capacity)
val newContext = newCoroutineContext(context)
val coroutine = FlowProduceCoroutine(newContext, channel)
coroutine.start(CoroutineStart.ATOMIC, coroutine, block)
return coroutine
}

private class FlowCoroutine<T>(
context: CoroutineContext,
uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
public override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
}

private class FlowProduceCoroutine<T>(
parentContext: CoroutineContext,
channel: Channel<T>
) : ProducerCoroutine<T>(parentContext, channel) {
public override fun childCancelled(cause: Throwable): Boolean {
override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
Expand Down
6 changes: 3 additions & 3 deletions kotlinx-coroutines-core/common/src/flow/internal/Merge.kt
Expand Up @@ -22,7 +22,7 @@ internal class ChannelFlowTransformLatest<T, R>(

override suspend fun flowCollect(collector: FlowCollector<R>) {
assert { collector is SendingCollector } // So cancellation behaviour is not leaking into the downstream
flowScope {
coroutineScope {
var previousFlow: Job? = null
flow.collect { value ->
previousFlow?.apply {
Expand All @@ -49,7 +49,7 @@ internal class ChannelFlowMerge<T>(
ChannelFlowMerge(flow, concurrency, context, capacity, onBufferOverflow)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
return scope.produce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
Expand Down Expand Up @@ -87,7 +87,7 @@ internal class ChannelLimitedFlowMerge<T>(
ChannelLimitedFlowMerge(flows, context, capacity, onBufferOverflow)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
return scope.produce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
Expand Down
9 changes: 4 additions & 5 deletions kotlinx-coroutines-core/common/src/flow/operators/Merge.kt
Expand Up @@ -61,7 +61,7 @@ public fun <T, R> Flow<T>.flatMapConcat(transform: suspend (value: T) -> Flow<R>
* its concurrent merging so that only one properly configured channel is used for execution of merging logic.
*
* @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected
* at the same time. By default it is equal to [DEFAULT_CONCURRENCY].
* at the same time. By default, it is equal to [DEFAULT_CONCURRENCY].
*/
@FlowPreview
public fun <T, R> Flow<T>.flatMapMerge(
Expand All @@ -71,8 +71,7 @@ public fun <T, R> Flow<T>.flatMapMerge(
map(transform).flattenMerge(concurrency)

/**
* Flattens the given flow of flows into a single flow in a sequentially manner, without interleaving nested flows.
* This method is conceptually identical to `flattenMerge(concurrency = 1)` but has faster implementation.
* Flattens the given flow of flows into a single flow in a sequential manner, without interleaving nested flows.
*
* Inner flows are collected by this operator *sequentially*.
*/
Expand Down Expand Up @@ -119,7 +118,7 @@ public fun <T> merge(vararg flows: Flow<T>): Flow<T> = flows.asIterable().merge(
* Flattens the given flow of flows into a single flow with a [concurrency] limit on the number of
* concurrently collected flows.
*
* If [concurrency] is more than 1, then inner flows are be collected by this operator *concurrently*.
* If [concurrency] is more than 1, then inner flows are collected by this operator *concurrently*.
* With `concurrency == 1` this operator is identical to [flattenConcat].
*
* ### Operator fusion
Expand All @@ -131,7 +130,7 @@ public fun <T> merge(vararg flows: Flow<T>): Flow<T> = flows.asIterable().merge(
* and size of its output buffer can be changed by applying subsequent [buffer] operator.
*
* @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected
* at the same time. By default it is equal to [DEFAULT_CONCURRENCY].
* at the same time. By default, it is equal to [DEFAULT_CONCURRENCY].
*/
@FlowPreview
public fun <T> Flow<Flow<T>>.flattenMerge(concurrency: Int = DEFAULT_CONCURRENCY): Flow<T> {
Expand Down
Expand Up @@ -39,19 +39,14 @@ class FlatMapMergeFastPathTest : FlatMapMergeBaseTest() {

@Test
fun testCancellationExceptionDownstream() = runTest {
val flow = flow {
emit(1)
hang { expect(2) }
}.flatMapMerge {
val flow = flowOf(1, 2, 3).flatMapMerge {
flow {
emit(it)
expect(1)
throw CancellationException("")
}
}.buffer(64)

assertFailsWith<CancellationException>(flow)
finish(3)
assertEquals(listOf(1, 2, 3), flow.toList())
}

@Test
Expand Down
Expand Up @@ -69,19 +69,14 @@ class FlatMapMergeTest : FlatMapMergeBaseTest() {

@Test
fun testCancellationExceptionDownstream() = runTest {
val flow = flow {
emit(1)
hang { expect(2) }
}.flatMapMerge {
val flow = flowOf(1, 2, 3).flatMapMerge {
flow {
emit(it)
expect(1)
throw CancellationException("")
}
}

assertFailsWith<CancellationException>(flow)
finish(3)
assertEquals(listOf(1, 2, 3), flow.toList())
}

@Test
Expand Down
Expand Up @@ -36,4 +36,17 @@ class FlattenConcatTest : FlatMapBaseTest() {
consumer.cancelAndJoin()
finish(2)
}

@Test
fun testCancellation() = runTest {
val flow = flow {
repeat(5) {
emit(flow {
if (it == 2) throw CancellationException("")
emit(1)
})
}
}
assertFailsWith<CancellationException>(flow.flattenConcat())
}
}
58 changes: 58 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt
Expand Up @@ -45,6 +45,64 @@ abstract class MergeTest : TestBase() {
assertEquals(listOf("source"), result)
}

@Test
fun testOneSourceCancelled() = runTest {
val flow = flow {
expect(1)
emit(1)
expect(2)
yield()
throw CancellationException("")
}

val otherFlow = flow {
repeat(5) {
emit(1)
yield()
}

expect(3)
}

val result = listOf(flow, otherFlow).merge().toList()
assertEquals(MutableList(6) { 1 }, result)
finish(4)
}

@Test
fun testOneSourceCancelledNonFused() = runTest {
val flow = flow {
expect(1)
emit(1)
expect(2)
yield()
throw CancellationException("")
}

val otherFlow = flow {
repeat(5) {
emit(1)
yield()
}

expect(3)
}

val result = listOf(flow, otherFlow).nonFuseableMerge().toList()
assertEquals(MutableList(6) { 1 }, result)
finish(4)
}

private fun <T> Iterable<Flow<T>>.nonFuseableMerge(): Flow<T> {
return channelFlow {
forEach { flow ->
launch {
flow.collect { send(it) }
}
}
}
}

@Test
fun testIsolatedContext() = runTest {
val flow = flow {
Expand Down

0 comments on commit 80af499

Please sign in to comment.