Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not propagate cancellation to the upstream in Flow flat* operators #2964

Merged
merged 2 commits into from Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Expand Up @@ -137,7 +137,7 @@ internal fun <E> CoroutineScope.produce(
return coroutine
}

internal open class ProducerCoroutine<E>(
private class ProducerCoroutine<E>(
parentContext: CoroutineContext, channel: Channel<E>
) : ChannelCoroutine<E>(parentContext, channel, true, active = true), ProducerScope<E> {
override val isActive: Boolean
Expand Down
Expand Up @@ -51,33 +51,11 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
flowScope { block(this@flow) }
}

internal fun <T> CoroutineScope.flowProduce(
context: CoroutineContext,
capacity: Int = 0,
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
): ReceiveChannel<T> {
val channel = Channel<T>(capacity)
val newContext = newCoroutineContext(context)
val coroutine = FlowProduceCoroutine(newContext, channel)
coroutine.start(CoroutineStart.ATOMIC, coroutine, block)
return coroutine
}

private class FlowCoroutine<T>(
context: CoroutineContext,
uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
public override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
}

private class FlowProduceCoroutine<T>(
parentContext: CoroutineContext,
channel: Channel<T>
) : ProducerCoroutine<T>(parentContext, channel) {
public override fun childCancelled(cause: Throwable): Boolean {
override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
Expand Down
6 changes: 3 additions & 3 deletions kotlinx-coroutines-core/common/src/flow/internal/Merge.kt
Expand Up @@ -22,7 +22,7 @@ internal class ChannelFlowTransformLatest<T, R>(

override suspend fun flowCollect(collector: FlowCollector<R>) {
assert { collector is SendingCollector } // So cancellation behaviour is not leaking into the downstream
flowScope {
coroutineScope {
var previousFlow: Job? = null
flow.collect { value ->
previousFlow?.apply {
Expand All @@ -49,7 +49,7 @@ internal class ChannelFlowMerge<T>(
ChannelFlowMerge(flow, concurrency, context, capacity, onBufferOverflow)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
return scope.produce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
Expand Down Expand Up @@ -87,7 +87,7 @@ internal class ChannelLimitedFlowMerge<T>(
ChannelLimitedFlowMerge(flows, context, capacity, onBufferOverflow)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
return scope.produce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/flow/operators/Merge.kt
Expand Up @@ -61,7 +61,7 @@ public fun <T, R> Flow<T>.flatMapConcat(transform: suspend (value: T) -> Flow<R>
* its concurrent merging so that only one properly configured channel is used for execution of merging logic.
*
* @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected
* at the same time. By default it is equal to [DEFAULT_CONCURRENCY].
* at the same time. By default, it is equal to [DEFAULT_CONCURRENCY].
*/
@FlowPreview
public fun <T, R> Flow<T>.flatMapMerge(
Expand Down
Expand Up @@ -39,19 +39,14 @@ class FlatMapMergeFastPathTest : FlatMapMergeBaseTest() {

@Test
fun testCancellationExceptionDownstream() = runTest {
val flow = flow {
emit(1)
hang { expect(2) }
}.flatMapMerge {
val flow = flowOf(1, 2, 3).flatMapMerge {
flow {
emit(it)
expect(1)
throw CancellationException("")
}
}.buffer(64)

assertFailsWith<CancellationException>(flow)
finish(3)
assertEquals(listOf(1, 2, 3), flow.toList())
dkhalanskyjb marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
Expand Down
Expand Up @@ -69,19 +69,14 @@ class FlatMapMergeTest : FlatMapMergeBaseTest() {

@Test
fun testCancellationExceptionDownstream() = runTest {
val flow = flow {
emit(1)
hang { expect(2) }
}.flatMapMerge {
val flow = flowOf(1, 2, 3).flatMapMerge {
flow {
emit(it)
expect(1)
throw CancellationException("")
}
}

assertFailsWith<CancellationException>(flow)
finish(3)
assertEquals(listOf(1, 2, 3), flow.toList())
}

@Test
Expand Down
Expand Up @@ -36,4 +36,17 @@ class FlattenConcatTest : FlatMapBaseTest() {
consumer.cancelAndJoin()
finish(2)
}

@Test
fun testCancellation() = runTest {
val flow = flow {
repeat(5) {
emit(flow {
if (it == 2) throw CancellationException("")
emit(1)
})
}
}
assertFailsWith<CancellationException>(flow.flattenConcat())
}
}
58 changes: 58 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt
Expand Up @@ -45,6 +45,64 @@ abstract class MergeTest : TestBase() {
assertEquals(listOf("source"), result)
}

@Test
fun testOneSourceCancelled() = runTest {
val flow = flow {
expect(1)
emit(1)
expect(2)
yield()
throw CancellationException("")
}

val otherFlow = flow {
repeat(5) {
emit(1)
yield()
}

expect(3)
}

val result = listOf(flow, otherFlow).merge().toList()
assertEquals(MutableList(6) { 1 }, result)
finish(4)
}

@Test
fun testOneSourceCancelledNonFused() = runTest {
val flow = flow {
expect(1)
emit(1)
expect(2)
yield()
throw CancellationException("")
}

val otherFlow = flow {
repeat(5) {
emit(1)
yield()
}

expect(3)
}

val result = listOf(flow, otherFlow).nonFuseableMerge().toList()
assertEquals(MutableList(6) { 1 }, result)
finish(4)
}

private fun <T> Iterable<Flow<T>>.nonFuseableMerge(): Flow<T> {
return channelFlow {
forEach { flow ->
launch {
flow.collect { send(it) }
}
}
}
}

@Test
fun testIsolatedContext() = runTest {
val flow = flow {
Expand Down