Skip to content

Commit

Permalink
Make SafeCollector platform-specific declaration and enforce exceptio… (
Browse files Browse the repository at this point in the history
#1793)

  
* Make SafeCollector platform-specific declaration and enforce exception transparency invariant on JVM
    * Make it in an allocation-free manner by using a crafty trick with casting KSuspendFunction to Function and pass a reusable object as a completion

Fixes #1657
  • Loading branch information
qwwdfsad committed Feb 13, 2020
1 parent b64a23b commit de491d2
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 156 deletions.
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Expand Up @@ -1001,7 +1001,7 @@ public final class kotlinx/coroutines/flow/internal/FlowExceptions_commonKt {
public static final fun checkIndexOverflow (I)I
}

public final class kotlinx/coroutines/flow/internal/SafeCollectorKt {
public final class kotlinx/coroutines/flow/internal/SafeCollector_commonKt {
public static final fun unsafeFlow (Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
}

Expand Down
7 changes: 6 additions & 1 deletion kotlinx-coroutines-core/common/src/flow/Builders.kt
Expand Up @@ -51,7 +51,12 @@ public fun <T> flow(@BuilderInference block: suspend FlowCollector<T>.() -> Unit
// Named anonymous object
private class SafeFlow<T>(private val block: suspend FlowCollector<T>.() -> Unit) : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>) {
SafeCollector(collector, coroutineContext).block()
val safeCollector = SafeCollector(collector, coroutineContext)
try {
safeCollector.block()
} finally {
safeCollector.releaseIntercepted()
}
}
}

Expand Down
13 changes: 9 additions & 4 deletions kotlinx-coroutines-core/common/src/flow/Flow.kt
Expand Up @@ -5,7 +5,7 @@
package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.SafeCollector
import kotlinx.coroutines.flow.internal.*
import kotlin.coroutines.*

/**
Expand Down Expand Up @@ -149,8 +149,8 @@ import kotlin.coroutines.*
* it hard to reason about the code because an exception in the `collect { ... }` could be somehow "caught"
* by an upstream flow, limiting the ability of local reasoning about the code.
*
* Currently, the flow infrastructure does not enforce exception transparency contracts, however, it might be enforced
* in the future either at run time or at compile time.
* Flow machinery enforces exception transparency at runtime and throws [IllegalStateException] on any attempt to emit a value,
* if an exception has been thrown on previous attempt.
*
* ### Reactive streams
*
Expand Down Expand Up @@ -199,7 +199,12 @@ public abstract class AbstractFlow<T> : Flow<T> {

@InternalCoroutinesApi
public final override suspend fun collect(collector: FlowCollector<T>) {
collectSafely(SafeCollector(collector, collectContext = coroutineContext))
val safeCollector = SafeCollector(collector, coroutineContext)
try {
collectSafely(safeCollector)
} finally {
safeCollector.releaseIntercepted()
}
}

/**
Expand Down
@@ -0,0 +1,111 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.ScopeCoroutine
import kotlin.coroutines.*
import kotlin.jvm.*

internal expect class SafeCollector<T>(
collector: FlowCollector<T>,
collectContext: CoroutineContext
) : FlowCollector<T> {
internal val collector: FlowCollector<T>
internal val collectContext: CoroutineContext
internal val collectContextSize: Int
public fun releaseIntercepted()
}

@JvmName("checkContext") // For prettier stack traces
internal fun SafeCollector<*>.checkContext(currentContext: CoroutineContext) {
val result = currentContext.fold(0) fold@{ count, element ->
val key = element.key
val collectElement = collectContext[key]
if (key !== Job) {
return@fold if (element !== collectElement) Int.MIN_VALUE
else count + 1
}

val collectJob = collectElement as Job?
val emissionParentJob = (element as Job).transitiveCoroutineParent(collectJob)
/*
* Code like
* ```
* coroutineScope {
* launch {
* emit(1)
* }
*
* launch {
* emit(2)
* }
* }
* ```
* is prohibited because 'emit' is not thread-safe by default. Use 'channelFlow' instead if you need concurrent emission
* or want to switch context dynamically (e.g. with `withContext`).
*
* Note that collecting from another coroutine is allowed, e.g.:
* ```
* coroutineScope {
* val channel = produce {
* collect { value ->
* send(value)
* }
* }
* channel.consumeEach { value ->
* emit(value)
* }
* }
* ```
* is a completely valid.
*/
if (emissionParentJob !== collectJob) {
error(
"Flow invariant is violated:\n" +
"\t\tEmission from another coroutine is detected.\n" +
"\t\tChild of $emissionParentJob, expected child of $collectJob.\n" +
"\t\tFlowCollector is not thread-safe and concurrent emissions are prohibited.\n" +
"\t\tTo mitigate this restriction please use 'channelFlow' builder instead of 'flow'"
)
}

/*
* If collect job is null (-> EmptyCoroutineContext, probably run from `suspend fun main`), then invariant is maintained
* (common transitive parent is "null"), but count check will fail, so just do not count job context element when
* flow is collected from EmptyCoroutineContext
*/
if (collectJob == null) count else count + 1
}
if (result != collectContextSize) {
error(
"Flow invariant is violated:\n" +
"\t\tFlow was collected in $collectContext,\n" +
"\t\tbut emission happened in $currentContext.\n" +
"\t\tPlease refer to 'flow' documentation or use 'flowOn' instead"
)
}
}

internal tailrec fun Job?.transitiveCoroutineParent(collectJob: Job?): Job? {
if (this === null) return null
if (this === collectJob) return this
if (this !is ScopeCoroutine<*>) return this
return parent.transitiveCoroutineParent(collectJob)
}

/**
* An analogue of the [flow] builder that does not check the context of execution of the resulting flow.
* Used in our own operators where we trust the context of invocations.
*/
@PublishedApi
internal inline fun <T> unsafeFlow(@BuilderInference crossinline block: suspend FlowCollector<T>.() -> Unit): Flow<T> {
return object : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>) {
collector.block()
}
}
}
124 changes: 0 additions & 124 deletions kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt

This file was deleted.

14 changes: 12 additions & 2 deletions kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt
Expand Up @@ -71,7 +71,12 @@ internal inline fun <T, R> Flow<T>.unsafeTransform(
public fun <T> Flow<T>.onStart(
action: suspend FlowCollector<T>.() -> Unit
): Flow<T> = unsafeFlow { // Note: unsafe flow is used here, but safe collector is used to invoke start action
SafeCollector<T>(this, coroutineContext).action()
val safeCollector = SafeCollector<T>(this, coroutineContext)
try {
safeCollector.action()
} finally {
safeCollector.releaseIntercepted()
}
collect(this) // directly delegate
}

Expand Down Expand Up @@ -141,7 +146,12 @@ public fun <T> Flow<T>.onCompletion(
throw e
}
// Exception from the upstream or normal completion
SafeCollector(this, coroutineContext).invokeSafely(action, exception)
val safeCollector = SafeCollector(this, coroutineContext)
try {
safeCollector.invokeSafely(action, exception)
} finally {
safeCollector.releaseIntercepted()
}
exception?.let { throw it }
}

Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/test/TestBase.common.kt
Expand Up @@ -50,7 +50,7 @@ public suspend inline fun <reified T : Throwable> assertFailsWith(flow: Flow<*>)
flow.collect()
fail("Should be unreached")
} catch (e: Throwable) {
assertTrue(e is T)
assertTrue(e is T, "Expected exception ${T::class}, but had $e instead")
}
}

Expand Down
27 changes: 4 additions & 23 deletions kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt
Expand Up @@ -59,25 +59,6 @@ class FlowInvariantsTest : TestBase() {
}
}

@Test
fun testCachedInvariantCheckResult() = runParametrizedTest<Int> { flow ->
flow {
emit(1)
try {
withContext(NamedDispatchers("foo")) {
emit(1)
}
fail()
} catch (e: IllegalStateException) {
expect(2)
}
emit(3)
}.collect {
expect(it)
}
finish(4)
}

@Test
fun testWithNameContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
flow {
Expand Down Expand Up @@ -146,9 +127,9 @@ class FlowInvariantsTest : TestBase() {
}
}

val flow = flowOf(1)
assertFailsWith<IllegalStateException> { flow.merge(flow).toList() }
assertFailsWith<IllegalStateException> { flow.trickyMerge(flow).toList() }
val flowInstance = flowOf(1)
assertFailsWith<IllegalStateException> { flowInstance.merge(flowInstance).toList() }
assertFailsWith<IllegalStateException> { flowInstance.trickyMerge(flowInstance).toList() }
}

@Test
Expand Down Expand Up @@ -237,7 +218,7 @@ class FlowInvariantsTest : TestBase() {
emptyContextTest {
transform {
expect(it)
kotlinx.coroutines.withContext(Dispatchers.Unconfined) {
withContext(Dispatchers.Unconfined) {
emit(it + 1)
}
}
Expand Down

0 comments on commit de491d2

Please sign in to comment.