diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 6df47e179f..a910db9a56 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -987,6 +987,7 @@ public abstract class kotlinx/coroutines/flow/internal/ChannelFlow : kotlinx/cor public fun collect (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected abstract fun collectTo (Lkotlinx/coroutines/channels/ProducerScope;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected abstract fun create (Lkotlin/coroutines/CoroutineContext;I)Lkotlinx/coroutines/flow/internal/ChannelFlow; + protected final fun getProduceCapacity ()I public fun produceImpl (Lkotlinx/coroutines/CoroutineScope;)Lkotlinx/coroutines/channels/ReceiveChannel; public fun toString ()Ljava/lang/String; public final fun update (Lkotlin/coroutines/CoroutineContext;I)Lkotlinx/coroutines/flow/internal/ChannelFlow; diff --git a/kotlinx-coroutines-core/common/src/channels/Channel.kt b/kotlinx-coroutines-core/common/src/channels/Channel.kt index 07e05f07d9..27334fee72 100644 --- a/kotlinx-coroutines-core/common/src/channels/Channel.kt +++ b/kotlinx-coroutines-core/common/src/channels/Channel.kt @@ -586,4 +586,4 @@ public class ClosedSendChannelException(message: String?) : IllegalStateExceptio * * This exception is a subclass of [NoSuchElementException] to be consistent with plain collections. */ -public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message) \ No newline at end of file +public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message) diff --git a/kotlinx-coroutines-core/common/src/channels/Produce.kt b/kotlinx-coroutines-core/common/src/channels/Produce.kt index 68fb09a41c..a0c3284240 100644 --- a/kotlinx-coroutines-core/common/src/channels/Produce.kt +++ b/kotlinx-coroutines-core/common/src/channels/Produce.kt @@ -27,7 +27,7 @@ public interface ProducerScope : CoroutineScope, SendChannel { /** * Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel] - * and invokes the given [block] before resuming the coroutine. + * and invokes the given [block] before resuming the coroutine. This suspending function is cancellable. * * Note that when the producer channel is cancelled, this function resumes with a cancellation exception. * Therefore, in case of cancellation, no code after the call to this function will be executed. diff --git a/kotlinx-coroutines-core/common/src/flow/Builders.kt b/kotlinx-coroutines-core/common/src/flow/Builders.kt index 49ad2922e9..af5e9dc133 100644 --- a/kotlinx-coroutines-core/common/src/flow/Builders.kt +++ b/kotlinx-coroutines-core/common/src/flow/Builders.kt @@ -11,9 +11,9 @@ import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.channels.Channel.Factory.BUFFERED import kotlinx.coroutines.flow.internal.* -import kotlinx.coroutines.flow.internal.unsafeFlow as flow import kotlin.coroutines.* import kotlin.jvm.* +import kotlinx.coroutines.flow.internal.unsafeFlow as flow /** * Creates a flow from the given suspendable [block]. @@ -259,10 +259,16 @@ public fun channelFlow(@BuilderInference block: suspend ProducerScope.() * * This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used * from any context, e.g. from a callback-based API. - * The resulting flow completes as soon as the code in the [block] and all its children completes. - * Use [awaitClose] as the last statement to keep it running. - * The [awaitClose] argument is called either when a flow consumer cancels the flow collection - * or when a callback-based API invokes [SendChannel.close] manually. + * The resulting flow completes as soon as the code in the [block] completes. + * [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately + * when block completes. + * [awaitClose] argument is called either when a flow consumer cancels the flow collection + * or when a callback-based API invokes [SendChannel.close] manually and is typically used + * to cleanup the resources after the completion, e.g. unregister a callback. + * Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled, + * otherwise the callback may keep running even when the flow collector is already completed. + * To avoid such leaks, this method throws [IllegalStateException] if block returns, but the channel + * is not closed yet. * * A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the * resulting flow to specify a user-defined value and to control what happens when data is produced faster @@ -277,9 +283,13 @@ public fun channelFlow(@BuilderInference block: suspend ProducerScope.() * fun flowFrom(api: CallbackBasedApi): Flow = callbackFlow { * val callback = object : Callback { // implementation of some callback interface * override fun onNextValue(value: T) { - * // Note: offer drops value when buffer is full - * // Use either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill - * offer(value) + * // To avoid blocking you can configure channel capacity using + * // either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill + * try { + * sendBlocking(value) + * } catch (e: Exception) { + * // Handle exception from the channel: failure in flow or premature closing + * } * } * override fun onApiError(cause: Throwable) { * cancel(CancellationException("API Error", cause)) @@ -287,21 +297,20 @@ public fun channelFlow(@BuilderInference block: suspend ProducerScope.() * override fun onCompleted() = channel.close() * } * api.register(callback) - * // Suspend until either onCompleted or external cancellation are invoked + * /* + * * Suspends until either 'onCompleted' from the callback is invoked + * * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's activity was destroyed). + * * In both cases, callback will be properly unregistered. + * */ * awaitClose { api.unregister(callback) } * } * ``` - * - * This function is an alias for [channelFlow], it has a separate name to reflect - * the intent of the usage (integration with a callback-based API) better. */ -@Suppress("NOTHING_TO_INLINE") @ExperimentalCoroutinesApi -public inline fun callbackFlow(@BuilderInference noinline block: suspend ProducerScope.() -> Unit): Flow = - channelFlow(block) +public fun callbackFlow(@BuilderInference block: suspend ProducerScope.() -> Unit): Flow = CallbackFlowBuilder(block) // ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow -private class ChannelFlowBuilder( +private open class ChannelFlowBuilder( private val block: suspend ProducerScope.() -> Unit, context: CoroutineContext = EmptyCoroutineContext, capacity: Int = BUFFERED @@ -315,3 +324,31 @@ private class ChannelFlowBuilder( override fun toString(): String = "block[$block] -> ${super.toString()}" } + +private class CallbackFlowBuilder( + private val block: suspend ProducerScope.() -> Unit, + context: CoroutineContext = EmptyCoroutineContext, + capacity: Int = BUFFERED +) : ChannelFlowBuilder(block, context, capacity) { + + override suspend fun collectTo(scope: ProducerScope) { + super.collectTo(scope) + /* + * We expect user either call `awaitClose` from within a block (then the channel is closed at this moment) + * or being closed/cancelled externally/manually. Otherwise "user forgot to call + * awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected. + */ + if (!scope.isClosedForSend) { + throw IllegalStateException( + """ + 'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block. + Otherwise, a callback/listener may leak in case of cancellation external cancellation (e.g. by 'take(1)' or destroyed activity). + See callbackFlow API documentation for the details. + """.trimIndent() + ) + } + } + + override fun create(context: CoroutineContext, capacity: Int): ChannelFlow = + CallbackFlowBuilder(block, context, capacity) +} diff --git a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt index 4711b88418..40df591632 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt @@ -27,6 +27,14 @@ public abstract class ChannelFlow( // buffer capacity between upstream and downstream context @JvmField val capacity: Int ) : Flow { + + // shared code to create a suspend lambda from collectTo function in one place + internal val collectToFun: suspend (ProducerScope) -> Unit + get() = { collectTo(it) } + + protected val produceCapacity: Int + get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity + public fun update( context: CoroutineContext = EmptyCoroutineContext, capacity: Int = Channel.OPTIONAL_CHANNEL @@ -57,13 +65,6 @@ public abstract class ChannelFlow( protected abstract suspend fun collectTo(scope: ProducerScope) - // shared code to create a suspend lambda from collectTo function in one place - internal val collectToFun: suspend (ProducerScope) -> Unit - get() = { collectTo(it) } - - private val produceCapacity: Int - get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity - open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel = scope.broadcast(context, produceCapacity, start, block = collectToFun) @@ -75,11 +76,11 @@ public abstract class ChannelFlow( collector.emitAll(produceImpl(this)) } + open fun additionalToStringProps() = "" + // debug toString override fun toString(): String = "$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]" - - open fun additionalToStringProps() = "" } // ChannelFlow implementation that operates on another flow before it @@ -161,7 +162,7 @@ private suspend fun withContextUndispatched( countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed block: suspend (V) -> T, value: V ): T = - suspendCoroutineUninterceptedOrReturn sc@{ uCont -> + suspendCoroutineUninterceptedOrReturn { uCont -> withCoroutineContext(newContext, countOrElement) { block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) { uCont.resumeWith(it) diff --git a/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt b/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt index bf85c74f64..885f1d6c8f 100644 --- a/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt +++ b/kotlinx-coroutines-core/common/test/channels/ProduceTest.kt @@ -5,6 +5,7 @@ package kotlinx.coroutines.channels import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* import kotlin.coroutines.* import kotlin.test.* @@ -143,9 +144,20 @@ class ProduceTest : TestBase() { @Test fun testAwaitIllegalState() = runTest { - val channel = produce { } - @Suppress("RemoveExplicitTypeArguments") // KT-31525 + val channel = produce { } assertFailsWith { (channel as ProducerScope<*>).awaitClose() } + callbackFlow { + expect(1) + launch { + expect(2) + assertFailsWith { + awaitClose { expectUnreached() } + expectUnreached() + } + } + close() + }.collect() + finish(3) } private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply { diff --git a/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt b/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt index 32c2afc65b..b115150a0b 100644 --- a/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt @@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() { finish(6) } + + @Test + fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) { + val outerScope = this + val flow = channelFlow { + // ~ callback-based API, no children + outerScope.launch(Job()) { + expect(2) + send(1) + expectUnreached() + } + expect(1) + } + assertEquals(emptyList(), flow.toList()) + finish(3) + } + + @Test + fun testNotClosedPrematurely() = runTest { + val outerScope = this + val flow = channelFlow { + // ~ callback-based API + outerScope.launch(Job()) { + expect(2) + send(1) + close() + } + expect(1) + awaitClose() + } + + assertEquals(listOf(1), flow.toList()) + finish(3) + } } diff --git a/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt b/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt index a6b5340555..cfbf242c35 100644 --- a/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt @@ -12,25 +12,35 @@ import kotlin.test.* class FlowCallbackTest : TestBase() { @Test - fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) { + fun testClosedPrematurely() = runTest { val outerScope = this - val flow = channelFlow { + val flow = callbackFlow { // ~ callback-based API outerScope.launch(Job()) { expect(2) - send(1) - expectUnreached() + try { + send(1) + expectUnreached() + } catch (e: IllegalStateException) { + expect(3) + assertTrue(e.message!!.contains("awaitClose")) + } } expect(1) } - assertEquals(emptyList(), flow.toList()) - finish(3) + try { + flow.collect() + } catch (e: IllegalStateException) { + expect(4) + assertTrue(e.message!!.contains("awaitClose")) + } + finish(5) } @Test fun testNotClosedPrematurely() = runTest { val outerScope = this - val flow = channelFlow { + val flow = callbackFlow { // ~ callback-based API outerScope.launch(Job()) { expect(2) diff --git a/kotlinx-coroutines-core/jvm/src/channels/Channels.kt b/kotlinx-coroutines-core/jvm/src/channels/Channels.kt index 78889e70ab..2c9499597f 100644 --- a/kotlinx-coroutines-core/jvm/src/channels/Channels.kt +++ b/kotlinx-coroutines-core/jvm/src/channels/Channels.kt @@ -1,5 +1,5 @@ /* - * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ @file:JvmMultifileClass @@ -9,8 +9,6 @@ package kotlinx.coroutines.channels import kotlinx.coroutines.* -// -------- Operations on SendChannel -------- - /** * Adds [element] into to this channel, **blocking** the caller while this channel [Channel.isFull], * or throws exception if the channel [Channel.isClosedForSend] (see [Channel.close] for details). diff --git a/kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt b/kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt index dab7d5d033..9c37b7bf50 100644 --- a/kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt +++ b/kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt @@ -10,36 +10,12 @@ class AsyncJvmTest : TestBase() { // This must be a common test but it fails on JS because of KT-21961 @Test fun testAsyncWithFinally() = runTest { - expect(1) + launch(Dispatchers.Default) { + + } + + launch(Dispatchers.IO) { - @Suppress("UNREACHABLE_CODE") - val d = async { - expect(3) - try { - yield() // to main, will cancel - } finally { - expect(6) // will go there on await - return@async "Fail" // result will not override cancellation - } - expectUnreached() - "Fail2" - } - expect(2) - yield() // to async - expect(4) - check(d.isActive && !d.isCompleted && !d.isCancelled) - d.cancel() - check(!d.isActive && !d.isCompleted && d.isCancelled) - check(!d.isActive && !d.isCompleted && d.isCancelled) - expect(5) - try { - d.await() // awaits - expectUnreached() // does not complete normally - } catch (e: Throwable) { - expect(7) - check(e is CancellationException) } - check(!d.isActive && d.isCompleted && d.isCancelled) - finish(8) } } diff --git a/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt b/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt index f71040343d..e3db2626ce 100644 --- a/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt +++ b/kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt @@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() { runCatching { it.offer(++i) } } - val flow = channelFlow { + val flow = callbackFlow { api.start(channel) awaitClose { api.stop() @@ -118,7 +118,7 @@ class CallbackFlowTest : TestBase() { } } - private fun Flow.merge(other: Flow): Flow = callbackFlow { + private fun Flow.merge(other: Flow): Flow = channelFlow { launch { collect { send(it) } }