Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect missing awaitClose calls in callbackFlow and close channel wit… #1771

Merged
merged 2 commits into from Feb 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Channel.kt
Expand Up @@ -586,4 +586,4 @@ public class ClosedSendChannelException(message: String?) : IllegalStateExceptio
*
* This exception is a subclass of [NoSuchElementException] to be consistent with plain collections.
*/
public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message)
public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message)
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Expand Up @@ -27,7 +27,7 @@ public interface ProducerScope<in E> : CoroutineScope, SendChannel<E> {

/**
* Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel]
* and invokes the given [block] before resuming the coroutine.
* and invokes the given [block] before resuming the coroutine. This suspending function is cancellable.
*
* Note that when the producer channel is cancelled, this function resumes with a cancellation exception.
* Therefore, in case of cancellation, no code after the call to this function will be executed.
Expand Down
69 changes: 53 additions & 16 deletions kotlinx-coroutines-core/common/src/flow/Builders.kt
Expand Up @@ -11,9 +11,9 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
import kotlinx.coroutines.flow.internal.*
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
import kotlin.coroutines.*
import kotlin.jvm.*
import kotlinx.coroutines.flow.internal.unsafeFlow as flow

/**
* Creates a flow from the given suspendable [block].
Expand Down Expand Up @@ -259,10 +259,16 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
*
* This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used
* from any context, e.g. from a callback-based API.
* The resulting flow completes as soon as the code in the [block] and all its children completes.
* Use [awaitClose] as the last statement to keep it running.
* The [awaitClose] argument is called either when a flow consumer cancels the flow collection
* or when a callback-based API invokes [SendChannel.close] manually.
* The resulting flow completes as soon as the code in the [block] completes.
* [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately
* when block completes.
* [awaitClose] argument is called either when a flow consumer cancels the flow collection
* or when a callback-based API invokes [SendChannel.close] manually and is typically used
* to cleanup the resources after the completion, e.g. unregister a callback.
* Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled,
* otherwise the callback may keep running even when the flow collector is already completed.
qwwdfsad marked this conversation as resolved.
Show resolved Hide resolved
* To avoid such leaks, this method throws [IllegalStateException] if block returns, but the channel
* is not closed yet.
*
* A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the
* resulting flow to specify a user-defined value and to control what happens when data is produced faster
Expand All @@ -277,31 +283,34 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
* fun flowFrom(api: CallbackBasedApi): Flow<T> = callbackFlow {
* val callback = object : Callback { // implementation of some callback interface
* override fun onNextValue(value: T) {
* // Note: offer drops value when buffer is full
* // Use either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill
* offer(value)
* // To avoid blocking you can configure channel capacity using
* // either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill
* try {
* sendBlocking(value)
* } catch (e: Exception) {
* // Handle exception from the channel: failure in flow or premature closing
* }
* }
* override fun onApiError(cause: Throwable) {
* cancel(CancellationException("API Error", cause))
* }
* override fun onCompleted() = channel.close()
* }
* api.register(callback)
* // Suspend until either onCompleted or external cancellation are invoked
* /*
* * Suspends until either 'onCompleted'/'onApiError' from the callback is invoked
* * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's coroutine was cancelled).
* * In both cases, callback will be properly unregistered.
* */
qwwdfsad marked this conversation as resolved.
Show resolved Hide resolved
* awaitClose { api.unregister(callback) }
* }
* ```
*
* This function is an alias for [channelFlow], it has a separate name to reflect
* the intent of the usage (integration with a callback-based API) better.
*/
@Suppress("NOTHING_TO_INLINE")
@ExperimentalCoroutinesApi
public inline fun <T> callbackFlow(@BuilderInference noinline block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
channelFlow(block)
public fun <T> callbackFlow(@BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> = CallbackFlowBuilder(block)
qwwdfsad marked this conversation as resolved.
Show resolved Hide resolved

// ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow
private class ChannelFlowBuilder<T>(
private open class ChannelFlowBuilder<T>(
private val block: suspend ProducerScope<T>.() -> Unit,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = BUFFERED
Expand All @@ -315,3 +324,31 @@ private class ChannelFlowBuilder<T>(
override fun toString(): String =
"block[$block] -> ${super.toString()}"
}

private class CallbackFlowBuilder<T>(
private val block: suspend ProducerScope<T>.() -> Unit,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = BUFFERED
) : ChannelFlowBuilder<T>(block, context, capacity) {

override suspend fun collectTo(scope: ProducerScope<T>) {
super.collectTo(scope)
/*
* We expect user either call `awaitClose` from within a block (then the channel is closed at this moment)
* or being closed/cancelled externally/manually. Otherwise "user forgot to call
* awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected.
*/
if (!scope.isClosedForSend) {
throw IllegalStateException(
"""
'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block.
Otherwise, a callback/listener may leak in case of external cancellation.
See callbackFlow API documentation for the details.
""".trimIndent()
)
}
}

override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
CallbackFlowBuilder(block, context, capacity)
}
21 changes: 11 additions & 10 deletions kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt
Expand Up @@ -27,6 +27,14 @@ public abstract class ChannelFlow<T>(
// buffer capacity between upstream and downstream context
@JvmField val capacity: Int
) : Flow<T> {

// shared code to create a suspend lambda from collectTo function in one place
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

private val produceCapacity: Int
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity

public fun update(
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = Channel.OPTIONAL_CHANNEL
Expand Down Expand Up @@ -57,13 +65,6 @@ public abstract class ChannelFlow<T>(

protected abstract suspend fun collectTo(scope: ProducerScope<T>)

// shared code to create a suspend lambda from collectTo function in one place
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

private val produceCapacity: Int
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity

open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
scope.broadcast(context, produceCapacity, start, block = collectToFun)

Expand All @@ -75,11 +76,11 @@ public abstract class ChannelFlow<T>(
collector.emitAll(produceImpl(this))
}

open fun additionalToStringProps() = ""

// debug toString
override fun toString(): String =
"$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]"

open fun additionalToStringProps() = ""
}

// ChannelFlow implementation that operates on another flow before it
Expand Down Expand Up @@ -161,7 +162,7 @@ private suspend fun <T, V> withContextUndispatched(
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
block: suspend (V) -> T, value: V
): T =
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
suspendCoroutineUninterceptedOrReturn { uCont ->
withCoroutineContext(newContext, countOrElement) {
block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) {
uCont.resumeWith(it)
Expand Down
16 changes: 14 additions & 2 deletions kotlinx-coroutines-core/common/test/channels/ProduceTest.kt
Expand Up @@ -5,6 +5,7 @@
package kotlinx.coroutines.channels

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlin.coroutines.*
import kotlin.test.*

Expand Down Expand Up @@ -143,9 +144,20 @@ class ProduceTest : TestBase() {

@Test
fun testAwaitIllegalState() = runTest {
val channel = produce<Int> { }
@Suppress("RemoveExplicitTypeArguments") // KT-31525
val channel = produce<Int> { }
assertFailsWith<IllegalStateException> { (channel as ProducerScope<*>).awaitClose() }
callbackFlow<Unit> {
expect(1)
launch {
expect(2)
assertFailsWith<IllegalStateException> {
awaitClose { expectUnreached() }
expectUnreached()
}
}
close()
}.collect()
finish(3)
}

private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply {
Expand Down
Expand Up @@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() {

finish(6)
}

@Test
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
val outerScope = this
val flow = channelFlow {
// ~ callback-based API, no children
outerScope.launch(Job()) {
expect(2)
send(1)
expectUnreached()
}
expect(1)
}
assertEquals(emptyList(), flow.toList())
finish(3)
}

@Test
fun testNotClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
send(1)
close()
}
expect(1)
awaitClose()
}

assertEquals(listOf(1), flow.toList())
finish(3)
}
}
Expand Up @@ -12,25 +12,35 @@ import kotlin.test.*

class FlowCallbackTest : TestBase() {
@Test
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
fun testClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow {
val flow = callbackFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
send(1)
expectUnreached()
try {
send(1)
expectUnreached()
} catch (e: IllegalStateException) {
expect(3)
assertTrue(e.message!!.contains("awaitClose"))
}
}
expect(1)
}
assertEquals(emptyList(), flow.toList())
finish(3)
try {
flow.collect()
} catch (e: IllegalStateException) {
expect(4)
assertTrue(e.message!!.contains("awaitClose"))
}
finish(5)
}

@Test
fun testNotClosedPrematurely() = runTest {
val outerScope = this
val flow = channelFlow<Int> {
val flow = callbackFlow {
// ~ callback-based API
outerScope.launch(Job()) {
expect(2)
Expand Down
4 changes: 1 addition & 3 deletions kotlinx-coroutines-core/jvm/src/channels/Channels.kt
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:JvmMultifileClass
Expand All @@ -9,8 +9,6 @@ package kotlinx.coroutines.channels

import kotlinx.coroutines.*

// -------- Operations on SendChannel --------

/**
* Adds [element] into to this channel, **blocking** the caller while this channel [Channel.isFull],
* or throws exception if the channel [Channel.isClosedForSend] (see [Channel.close] for details).
Expand Down
34 changes: 5 additions & 29 deletions kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt
Expand Up @@ -10,36 +10,12 @@ class AsyncJvmTest : TestBase() {
// This must be a common test but it fails on JS because of KT-21961
@Test
fun testAsyncWithFinally() = runTest {
expect(1)
launch(Dispatchers.Default) {

}

launch(Dispatchers.IO) {

@Suppress("UNREACHABLE_CODE")
val d = async {
expect(3)
try {
yield() // to main, will cancel
} finally {
expect(6) // will go there on await
return@async "Fail" // result will not override cancellation
}
expectUnreached()
"Fail2"
}
expect(2)
yield() // to async
expect(4)
check(d.isActive && !d.isCompleted && !d.isCancelled)
d.cancel()
check(!d.isActive && !d.isCompleted && d.isCancelled)
check(!d.isActive && !d.isCompleted && d.isCancelled)
expect(5)
try {
d.await() // awaits
expectUnreached() // does not complete normally
} catch (e: Throwable) {
expect(7)
check(e is CancellationException)
}
check(!d.isActive && d.isCompleted && d.isCancelled)
finish(8)
}
}
4 changes: 2 additions & 2 deletions kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt
Expand Up @@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() {
runCatching { it.offer(++i) }
}

val flow = channelFlow<Int> {
val flow = callbackFlow<Int> {
api.start(channel)
awaitClose {
api.stop()
Expand Down Expand Up @@ -118,7 +118,7 @@ class CallbackFlowTest : TestBase() {
}
}

private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = callbackFlow {
private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
launch {
collect { send(it) }
}
Expand Down