diff --git a/kotlinx-coroutines-core/common/src/channels/Produce.kt b/kotlinx-coroutines-core/common/src/channels/Produce.kt index 68fb09a41c..a0c3284240 100644 --- a/kotlinx-coroutines-core/common/src/channels/Produce.kt +++ b/kotlinx-coroutines-core/common/src/channels/Produce.kt @@ -27,7 +27,7 @@ public interface ProducerScope : CoroutineScope, SendChannel { /** * Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel] - * and invokes the given [block] before resuming the coroutine. + * and invokes the given [block] before resuming the coroutine. This suspending function is cancellable. * * Note that when the producer channel is cancelled, this function resumes with a cancellation exception. * Therefore, in case of cancellation, no code after the call to this function will be executed. diff --git a/kotlinx-coroutines-core/common/src/flow/Builders.kt b/kotlinx-coroutines-core/common/src/flow/Builders.kt index 49ad2922e9..95b3da55e2 100644 --- a/kotlinx-coroutines-core/common/src/flow/Builders.kt +++ b/kotlinx-coroutines-core/common/src/flow/Builders.kt @@ -11,9 +11,9 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.channels.Channel.Factory.BUFFERED import kotlinx.coroutines.flow.internal.* -import kotlinx.coroutines.flow.internal.unsafeFlow as flow import kotlin.coroutines.* import kotlin.jvm.* +import kotlinx.coroutines.flow.internal.unsafeFlow as flow /** * Creates a flow from the given suspendable [block]. @@ -259,10 +259,14 @@ public fun channelFlow(@BuilderInference block: suspend ProducerScope.() * * This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used * from any context, e.g. from a callback-based API. - * The resulting flow completes as soon as the code in the [block] and all its children completes. - * Use [awaitClose] as the last statement to keep it running. - * The [awaitClose] argument is called either when a flow consumer cancels the flow collection - * or when a callback-based API invokes [SendChannel.close] manually. + * The resulting flow completes as soon as the code in the [block] completes. + * [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately + * when block completes. + * [awaitClose] argument is called either when a flow consumer cancels the flow collection + * or when a callback-based API invokes [SendChannel.close] manually and is typically used + * to cleanup the resources after the completion, e.g. unregister a callback. + * Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled, + * otherwise the callback may keep running even when the flow collector is already completed. * * A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the * resulting flow to specify a user-defined value and to control what happens when data is produced faster @@ -287,21 +291,20 @@ public fun channelFlow(@BuilderInference block: suspend ProducerScope.() * override fun onCompleted() = channel.close() * } * api.register(callback) - * // Suspend until either onCompleted or external cancellation are invoked + * /* + * * Suspends until either 'onCompleted' from the callback is invoked + * * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's activity was destroyed). + * * In both cases, callback will be properly unregistered. + * */ * awaitClose { api.unregister(callback) } * } * ``` - * - * This function is an alias for [channelFlow], it has a separate name to reflect - * the intent of the usage (integration with a callback-based API) better. */ -@Suppress("NOTHING_TO_INLINE") @ExperimentalCoroutinesApi -public inline fun callbackFlow(@BuilderInference noinline block: suspend ProducerScope.() -> Unit): Flow = - channelFlow(block) +public fun callbackFlow(@BuilderInference block: suspend ProducerScope.() -> Unit): Flow = CallbackFlowBuilder(block) // ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow -private class ChannelFlowBuilder( +private open class ChannelFlowBuilder( private val block: suspend ProducerScope.() -> Unit, context: CoroutineContext = EmptyCoroutineContext, capacity: Int = BUFFERED @@ -315,3 +318,36 @@ private class ChannelFlowBuilder( override fun toString(): String = "block[$block] -> ${super.toString()}" } + +private class CallbackFlowBuilder( + private val block: suspend ProducerScope.() -> Unit, + context: CoroutineContext = EmptyCoroutineContext, + capacity: Int = BUFFERED +) : ChannelFlowBuilder(block, context, capacity) { + + private val collectCallback: suspend (ProducerScope) -> Unit = { + collectTo(it) + /* + * We expect user either call `awaitClose` from within a block (then the channel is closed at this moment) + * or being closed/cancelled externally/manually. Otherwise "user forgot to call + * awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected. + */ + if (it.isActive && !it.isClosedForSend) { + throw IllegalStateException( + """ + 'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block. + Otherwise, a callback/listener may leak in case of cancellation external cancellation (e.g. by 'take(1)' or destroyed activity). + For a more detailed explanation, please refer to callbackFlow KDoc. + """.trimIndent()) + } + } + + override fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel = + scope.broadcast(context, produceCapacity, start, block = collectCallback) + + override fun produceImpl(scope: CoroutineScope): ReceiveChannel = + scope.produce(context, produceCapacity, block = collectCallback) + + override fun create(context: CoroutineContext, capacity: Int): ChannelFlow = + CallbackFlowBuilder(block, context, capacity) +} diff --git a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt index 4711b88418..40df591632 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt @@ -27,6 +27,14 @@ public abstract class ChannelFlow( // buffer capacity between upstream and downstream context @JvmField val capacity: Int ) : Flow { + + // shared code to create a suspend lambda from collectTo function in one place + internal val collectToFun: suspend (ProducerScope) -> Unit + get() = { collectTo(it) } + + protected val produceCapacity: Int + get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity + public fun update( context: CoroutineContext = EmptyCoroutineContext, capacity: Int = Channel.OPTIONAL_CHANNEL @@ -57,13 +65,6 @@ public abstract class ChannelFlow( protected abstract suspend fun collectTo(scope: ProducerScope) - // shared code to create a suspend lambda from collectTo function in one place - internal val collectToFun: suspend (ProducerScope) -> Unit - get() = { collectTo(it) } - - private val produceCapacity: Int - get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity - open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel = scope.broadcast(context, produceCapacity, start, block = collectToFun) @@ -75,11 +76,11 @@ public abstract class ChannelFlow( collector.emitAll(produceImpl(this)) } + open fun additionalToStringProps() = "" + // debug toString override fun toString(): String = "$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]" - - open fun additionalToStringProps() = "" } // ChannelFlow implementation that operates on another flow before it @@ -161,7 +162,7 @@ private suspend fun withContextUndispatched( countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed block: suspend (V) -> T, value: V ): T = - suspendCoroutineUninterceptedOrReturn sc@{ uCont -> + suspendCoroutineUninterceptedOrReturn { uCont -> withCoroutineContext(newContext, countOrElement) { block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) { uCont.resumeWith(it) diff --git a/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt b/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt index bf85c74f64..885f1d6c8f 100644 --- a/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt +++ b/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt @@ -5,6 +5,7 @@ package kotlinx.coroutines.channels import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* import kotlin.coroutines.* import kotlin.test.* @@ -143,9 +144,20 @@ class ProduceTest : TestBase() { @Test fun testAwaitIllegalState() = runTest { - val channel = produce { } - @Suppress("RemoveExplicitTypeArguments") // KT-31525 + val channel = produce { } assertFailsWith { (channel as ProducerScope<*>).awaitClose() } + callbackFlow { + expect(1) + launch { + expect(2) + assertFailsWith { + awaitClose { expectUnreached() } + expectUnreached() + } + } + close() + }.collect() + finish(3) } private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply { diff --git a/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt b/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt index 32c2afc65b..b115150a0b 100644 --- a/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt @@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() { finish(6) } + + @Test + fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) { + val outerScope = this + val flow = channelFlow { + // ~ callback-based API, no children + outerScope.launch(Job()) { + expect(2) + send(1) + expectUnreached() + } + expect(1) + } + assertEquals(emptyList(), flow.toList()) + finish(3) + } + + @Test + fun testNotClosedPrematurely() = runTest { + val outerScope = this + val flow = channelFlow { + // ~ callback-based API + outerScope.launch(Job()) { + expect(2) + send(1) + close() + } + expect(1) + awaitClose() + } + + assertEquals(listOf(1), flow.toList()) + finish(3) + } } diff --git a/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt b/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt index a6b5340555..cfbf242c35 100644 --- a/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt @@ -12,25 +12,35 @@ import kotlin.test.* class FlowCallbackTest : TestBase() { @Test - fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) { + fun testClosedPrematurely() = runTest { val outerScope = this - val flow = channelFlow { + val flow = callbackFlow { // ~ callback-based API outerScope.launch(Job()) { expect(2) - send(1) - expectUnreached() + try { + send(1) + expectUnreached() + } catch (e: IllegalStateException) { + expect(3) + assertTrue(e.message!!.contains("awaitClose")) + } } expect(1) } - assertEquals(emptyList(), flow.toList()) - finish(3) + try { + flow.collect() + } catch (e: IllegalStateException) { + expect(4) + assertTrue(e.message!!.contains("awaitClose")) + } + finish(5) } @Test fun testNotClosedPrematurely() = runTest { val outerScope = this - val flow = channelFlow { + val flow = callbackFlow { // ~ callback-based API outerScope.launch(Job()) { expect(2) diff --git a/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt b/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt index f71040343d..d207113b74 100644 --- a/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt +++ b/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt @@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() { runCatching { it.offer(++i) } } - val flow = channelFlow { + val flow = callbackFlow { api.start(channel) awaitClose { api.stop()