Skip to content

Commit

Permalink
Detect missing awaitClose calls in callbackFlow and close channel wit…
Browse files Browse the repository at this point in the history
…h a proper diagnostic exception

Fixes #1762
Fixes #1770
  • Loading branch information
qwwdfsad committed Jan 20, 2020
1 parent f18e0e4 commit 4249758
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 35 deletions.
Expand Up @@ -987,6 +987,7 @@ public abstract class kotlinx/coroutines/flow/internal/ChannelFlow : kotlinx/cor
public fun collect (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
protected abstract fun collectTo (Lkotlinx/coroutines/channels/ProducerScope;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
protected abstract fun create (Lkotlin/coroutines/CoroutineContext;I)Lkotlinx/coroutines/flow/internal/ChannelFlow;
protected final fun getProduceCapacity ()I
public fun produceImpl (Lkotlinx/coroutines/CoroutineScope;)Lkotlinx/coroutines/channels/ReceiveChannel;
public fun toString ()Ljava/lang/String;
public final fun update (Lkotlin/coroutines/CoroutineContext;I)Lkotlinx/coroutines/flow/internal/ChannelFlow;
Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Expand Up @@ -27,7 +27,7 @@ public interface ProducerScope<in E> : CoroutineScope, SendChannel<E> {

/**
* Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel]
* and invokes the given [block] before resuming the coroutine.
* and invokes the given [block] before resuming the coroutine. This suspending function is cancellable.
*
* Note that when the producer channel is cancelled, this function resumes with a cancellation exception.
* Therefore, in case of cancellation, no code after the call to this function will be executed.
Expand Down
62 changes: 49 additions & 13 deletions kotlinx-coroutines-core/common/src/flow/Builders.kt
Expand Up @@ -11,9 +11,9 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
import kotlinx.coroutines.flow.internal.*
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
import kotlin.coroutines.*
import kotlin.jvm.*
import kotlinx.coroutines.flow.internal.unsafeFlow as flow

/**
* Creates a flow from the given suspendable [block].
Expand Down Expand Up @@ -259,10 +259,14 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
*
* This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used
* from any context, e.g. from a callback-based API.
* The resulting flow completes as soon as the code in the [block] and all its children completes.
* Use [awaitClose] as the last statement to keep it running.
* The [awaitClose] argument is called either when a flow consumer cancels the flow collection
* or when a callback-based API invokes [SendChannel.close] manually.
* The resulting flow completes as soon as the code in the [block] completes.
* [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately
* when block completes.
* [awaitClose] argument is called either when a flow consumer cancels the flow collection
* or when a callback-based API invokes [SendChannel.close] manually and is typically used
* to cleanup the resources after the completion, e.g. unregister a callback.
* Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled,
* otherwise the callback may keep running even when the flow collector is already completed.
*
* A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the
* resulting flow to specify a user-defined value and to control what happens when data is produced faster
Expand All @@ -287,21 +291,20 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
* override fun onCompleted() = channel.close()
* }
* api.register(callback)
* // Suspend until either onCompleted or external cancellation are invoked
* /*
* * Suspends until either 'onCompleted' from the callback is invoked
* * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's activity was destroyed).
* * In both cases, callback will be properly unregistered.
* */
* awaitClose { api.unregister(callback) }
* }
* ```
*
* This function is an alias for [channelFlow], it has a separate name to reflect
* the intent of the usage (integration with a callback-based API) better.
*/
@Suppress("NOTHING_TO_INLINE")
@ExperimentalCoroutinesApi
public inline fun <T> callbackFlow(@BuilderInference noinline block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
channelFlow(block)
public fun <T> callbackFlow(@BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> = CallbackFlowBuilder(block)

// ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow
private class ChannelFlowBuilder<T>(
private open class ChannelFlowBuilder<T>(
private val block: suspend ProducerScope<T>.() -> Unit,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = BUFFERED
Expand All @@ -315,3 +318,36 @@ private class ChannelFlowBuilder<T>(
override fun toString(): String =
"block[$block] -> ${super.toString()}"
}

private class CallbackFlowBuilder<T>(
private val block: suspend ProducerScope<T>.() -> Unit,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = BUFFERED
) : ChannelFlowBuilder<T>(block, context, capacity) {

private val collectCallback: suspend (ProducerScope<T>) -> Unit = {
collectTo(it)
/*
* We expect user either call `awaitClose` from within a block (then the channel is closed at this moment)
* or being closed/cancelled externally/manually. Otherwise "user forgot to call
* awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected.
*/
if (it.isActive && !it.isClosedForSend) {
throw IllegalStateException(
"""
'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block.
Otherwise, a callback/listener may leak in case of cancellation external cancellation (e.g. by 'take(1)' or destroyed activity).
For a more detailed explanation, please refer to callbackFlow KDoc.
""".trimIndent())
}
}

override fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
scope.broadcast(context, produceCapacity, start, block = collectCallback)

override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> =
scope.produce(context, produceCapacity, block = collectCallback)

override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
CallbackFlowBuilder(block, context, capacity)
}
21 changes: 11 additions & 10 deletions kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt
Expand Up @@ -27,6 +27,14 @@ public abstract class ChannelFlow<T>(
// buffer capacity between upstream and downstream context
@JvmField val capacity: Int
) : Flow<T> {

// shared code to create a suspend lambda from collectTo function in one place
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

protected val produceCapacity: Int
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity

public fun update(
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = Channel.OPTIONAL_CHANNEL
Expand Down Expand Up @@ -57,13 +65,6 @@ public abstract class ChannelFlow<T>(

protected abstract suspend fun collectTo(scope: ProducerScope<T>)

// shared code to create a suspend lambda from collectTo function in one place
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

private val produceCapacity: Int
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity

open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
scope.broadcast(context, produceCapacity, start, block = collectToFun)

Expand All @@ -75,11 +76,11 @@ public abstract class ChannelFlow<T>(
collector.emitAll(produceImpl(this))
}

open fun additionalToStringProps() = ""

// debug toString
override fun toString(): String =
"$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]"

open fun additionalToStringProps() = ""
}

// ChannelFlow implementation that operates on another flow before it
Expand Down Expand Up @@ -161,7 +162,7 @@ private suspend fun <T, V> withContextUndispatched(
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
block: suspend (V) -> T, value: V
): T =
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
suspendCoroutineUninterceptedOrReturn { uCont ->
withCoroutineContext(newContext, countOrElement) {
block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) {
uCont.resumeWith(it)
Expand Down
16 changes: 14 additions & 2 deletions kotlinx-coroutines-core/common/test/channels/ProduceTest.kt
Expand Up @@ -5,6 +5,7 @@
package kotlinx.coroutines.channels

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlin.coroutines.*
import kotlin.test.*

Expand Down Expand Up @@ -143,9 +144,20 @@ class ProduceTest : TestBase() {

@Test
fun testAwaitIllegalState() = runTest {
val channel = produce<Int> { }
@Suppress("RemoveExplicitTypeArguments") // KT-31525
val channel = produce<Int> { }
assertFailsWith<IllegalStateException> { (channel as ProducerScope<*>).awaitClose() }
callbackFlow<Unit> {
expect(1)
launch {
expect(2)
assertFailsWith<IllegalStateException> {
awaitClose { expectUnreached() }
expectUnreached()
}
}
close()
}.collect()
finish(3)
}

private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply {
Expand Down
Expand Up @@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() {

finish(6)
}

@Test
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
val outerScope = this
val flow = channelFlow {
// ~ callback-based API, no children
outerScope.launch(Job()) {
expect(2)
send(1)
expectUnreached()
}
expect(1)
}
assertEquals(emptyList(), flow.toList())
finish(3)
}

@Test
fun testNotClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
send(1)
close()
}
expect(1)
awaitClose()
}

assertEquals(listOf(1), flow.toList())
finish(3)
}
}
Expand Up @@ -12,25 +12,35 @@ import kotlin.test.*

class FlowCallbackTest : TestBase() {
@Test
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
fun testClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow {
val flow = callbackFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
send(1)
expectUnreached()
try {
send(1)
expectUnreached()
} catch (e: IllegalStateException) {
expect(3)
assertTrue(e.message!!.contains("awaitClose"))
}
}
expect(1)
}
assertEquals(emptyList(), flow.toList())
finish(3)
try {
flow.collect()
} catch (e: IllegalStateException) {
expect(4)
assertTrue(e.message!!.contains("awaitClose"))
}
finish(5)
}

@Test
fun testNotClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow<Int> {
val flow = callbackFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
Expand Down
4 changes: 2 additions & 2 deletions kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt
Expand Up @@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() {
runCatching { it.offer(++i) }
}

val flow = channelFlow<Int> {
val flow = callbackFlow<Int> {
api.start(channel)
awaitClose {
api.stop()
Expand Down Expand Up @@ -118,7 +118,7 @@ class CallbackFlowTest : TestBase() {
}
}

private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = callbackFlow {
private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
launch {
collect { send(it) }
}
Expand Down

0 comments on commit 4249758

Please sign in to comment.