Skip to content

Commit

Permalink
Restore thread context elements when directly resuming to parent.
Browse files Browse the repository at this point in the history
This fix solves the problem of restoring the thread context when
returning to another context in an undispatched way.

It impacts suspend/resume performance of coroutines that use ThreadContextElement and undispatched coroutines.

The kotlinx.coroutines code poisons the context with a special 'UndispatchedMarker' element, and a linear lookup is performed only when the marker is present. The code also contains a description of an alternative approach that would save a linear lookup in complex coroutine hierarchies.

The fast path of coroutine resumption is slowed down by a single context lookup.

Fixes #985

Co-authored-by: Roman Elizarov <elizarov@gmail.com>
  • Loading branch information
qwwdfsad and elizarov committed Feb 1, 2021
1 parent 7061cc2 commit 5dc55a6
Show file tree
Hide file tree
Showing 12 changed files with 420 additions and 22 deletions.
4 changes: 2 additions & 2 deletions benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt
Expand Up @@ -10,7 +10,7 @@ import org.openjdk.jmh.annotations.*
import java.util.concurrent.*
import kotlin.coroutines.*

@Warmup(iterations = 5, time = 1)
@Warmup(iterations = 7, time = 1)
@Measurement(iterations = 5, time = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
Expand Down Expand Up @@ -41,7 +41,7 @@ open class ChannelSinkBenchmark {

private suspend inline fun run(context: CoroutineContext): Int {
return Channel
.range(1, 1_000_000, context)
.range(1, 10_000, context)
.filter(context) { it % 4 == 0 }
.fold(0) { a, b -> a + b }
}
Expand Down
91 changes: 91 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkDepthBenchmark.kt
@@ -0,0 +1,91 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*
import kotlin.coroutines.*

/**
 * Benchmark that pushes 10_000 integers through a produce -> filter -> fold channel pipeline
 * while the filtering stage sits at the bottom of a suspending call chain of configurable depth.
 *
 * Every stage runs in [Dispatchers.Unconfined] combined with a single thread-local context element.
 */
@Warmup(iterations = 7, time = 1)
@Measurement(iterations = 5, time = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Fork(2)
open class ChannelSinkDepthBenchmark {
    // Thread-local whose context element is part of the context used by every pipeline stage.
    private val tl = ThreadLocal.withInitial { 42 }

    private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement()

    @Benchmark
    fun depth1(): Int = runBlocking {
        run(callTraceDepth = 1, context = unconfinedOneElement)
    }

    @Benchmark
    fun depth10(): Int = runBlocking {
        run(callTraceDepth = 10, context = unconfinedOneElement)
    }

    @Benchmark
    fun depth100(): Int = runBlocking {
        run(callTraceDepth = 100, context = unconfinedOneElement)
    }

    @Benchmark
    fun depth1000(): Int = runBlocking {
        run(callTraceDepth = 1000, context = unconfinedOneElement)
    }

    /**
     * Builds the pipeline in [context] and returns the sum of all produced values divisible by 4.
     * [callTraceDepth] is the nesting depth handed to the filtering stage.
     */
    private suspend inline fun run(callTraceDepth: Int, context: CoroutineContext): Int {
        val source = Channel.range(1, 10_000, context)
        val filtered = source.filter(callTraceDepth, context) { it % 4 == 0 }
        return filtered.fold(0) { a, b -> a + b }
    }

    /** Produces the integers `start, start + 1, ...` up to (excluding) `start + count`. */
    private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) =
        GlobalScope.produce(context) {
            val endExclusive = start + count
            for (value in start until endExclusive) {
                send(value)
            }
        }

    // The operators below were migrated from deprecated channel operators;
    // they are only good for stressing channels.

    /**
     * Forwards the elements of this channel that satisfy [predicate] into a new channel.
     * The filtering loop runs [callTraceDepth] suspending calls deep; this channel is
     * cancelled when the producer completes.
     */
    private fun ReceiveChannel<Int>.filter(
        callTraceDepth: Int,
        context: CoroutineContext = Dispatchers.Unconfined,
        predicate: suspend (Int) -> Boolean
    ): ReceiveChannel<Int> {
        val upstream = this
        return GlobalScope.produce(context, onCompletion = { upstream.cancel() }) {
            upstream.deeplyNestedFilter(this, callTraceDepth, predicate)
        }
    }

    /**
     * Recurses until [depth] drops to 1 (or below) and only then runs the actual
     * receive / test / send loop against [sink].
     */
    private suspend fun ReceiveChannel<Int>.deeplyNestedFilter(
        sink: ProducerScope<Int>,
        depth: Int,
        predicate: suspend (Int) -> Boolean
    ) {
        if (depth > 1) {
            deeplyNestedFilter(sink, depth - 1, predicate)
            // No-op placed after the recursive call. The original note here reads "tail-call";
            // presumably it keeps the call above out of tail position -- TODO confirm.
            require(true) // tail-call
        } else {
            for (element in this) {
                if (predicate(element)) sink.send(element)
            }
        }
    }

    /** Consumes the whole channel, folding its elements into a single value starting from [initial]. */
    private suspend inline fun <E, R> ReceiveChannel<E>.fold(initial: R, operation: (acc: R, E) -> R): R {
        var result = initial
        consumeEach { element -> result = operation(result, element) }
        return result
    }
}

14 changes: 3 additions & 11 deletions kotlinx-coroutines-core/common/src/Builders.common.kt
Expand Up @@ -207,25 +207,17 @@ private class LazyStandaloneCoroutine(
}

// Used by withContext when context changes, but dispatcher stays the same
private class UndispatchedCoroutine<in T>(
internal expect class UndispatchedCoroutine<in T>(
context: CoroutineContext,
uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
override fun afterResume(state: Any?) {
// resume undispatched -- update context by stay on the same dispatcher
val result = recoverResult(state, uCont)
withCoroutineContext(uCont.context, null) {
uCont.resumeWith(result)
}
}
}
) : ScopeCoroutine<T>

private const val UNDECIDED = 0
private const val SUSPENDED = 1
private const val RESUMED = 2

// Used by withContext when context dispatcher changes
private class DispatchedCoroutine<in T>(
internal class DispatchedCoroutine<in T>(
context: CoroutineContext,
uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
Expand Down
Expand Up @@ -19,5 +19,6 @@ internal expect val DefaultDelay: Delay

// Platform-specific (`expect`) hooks; each platform source set supplies the matching `actual`.
// countOrElement -- pre-cached value for ThreadContext.kt
// Runs [block] for the given [context]; what (if anything) is installed around it is up to the platform `actual`.
internal expect inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T
// Same as above, but takes the [continuation] itself rather than only a context.
internal expect inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T
// Platform-specific string representation of a continuation.
internal expect fun Continuation<*>.toDebugString(): String
// Name of the coroutine for this context; nullable, so a platform may report none.
internal expect val CoroutineContext.coroutineName: String?
Expand Up @@ -235,7 +235,7 @@ internal class DispatchedContinuation<in T>(

@Suppress("NOTHING_TO_INLINE") // we need it inline to save us an entry on the stack
inline fun resumeUndispatchedWith(result: Result<T>) {
withCoroutineContext(context, countOrElement) {
withContinuationContext(continuation, countOrElement) {
continuation.resumeWith(result)
}
}
Expand Down
6 changes: 3 additions & 3 deletions kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt
Expand Up @@ -85,9 +85,9 @@ internal abstract class DispatchedTask<in T>(
try {
val delegate = delegate as DispatchedContinuation<T>
val continuation = delegate.continuation
val context = continuation.context
val state = takeState() // NOTE: Must take state in any case, even if cancelled
withCoroutineContext(context, delegate.countOrElement) {
withContinuationContext(continuation, delegate.countOrElement) {
val context = continuation.context
val state = takeState() // NOTE: Must take state in any case, even if cancelled
val exception = getExceptionalResult(state)
/*
* Check whether continuation was originally resumed with an exception.
Expand Down
9 changes: 9 additions & 0 deletions kotlinx-coroutines-core/js/src/CoroutineContext.kt
Expand Up @@ -4,6 +4,7 @@

package kotlinx.coroutines

import kotlinx.coroutines.internal.*
import kotlin.browser.*
import kotlin.coroutines.*

Expand Down Expand Up @@ -49,5 +50,13 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext):

// No debugging facilities on JS
// Ignores both `context` and `countOrElement` and simply invokes `block`.
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
// Ignores both `continuation` and `countOrElement` and simply invokes `block`.
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()
// Falls back to the continuation's own toString().
internal actual fun Continuation<*>.toDebugString(): String = toString()
// Always null here.
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on JS

/**
 * JS implementation of the undispatched coroutine: after resumption the recovered result
 * is passed straight to [uCont], with no context switching around the call.
 */
internal actual class UndispatchedCoroutine<in T> actual constructor(
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
    override fun afterResume(state: Any?) {
        val result = recoverResult(state, uCont)
        uCont.resumeWith(result)
    }
}
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/jvm/src/Builders.kt
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:JvmMultifileClass
Expand Down
97 changes: 97 additions & 0 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Expand Up @@ -7,6 +7,7 @@ package kotlinx.coroutines
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.scheduling.*
import kotlin.coroutines.*
import kotlin.coroutines.jvm.internal.CoroutineStackFrame

internal const val COROUTINES_SCHEDULER_PROPERTY_NAME = "kotlinx.coroutines.scheduler"

Expand Down Expand Up @@ -47,6 +48,102 @@ internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, c
}
}

/**
 * Runs [block] with the context of the given [continuation] installed via [updateThreadContext],
 * and takes care of restoring the previous thread context afterwards.
 *
 * [countOrElement] is passed through to [updateThreadContext] unchanged.
 */
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
    val context = continuation.context
    val oldValue = updateThreadContext(context, countOrElement)
    // Fast path: NO_THREAD_ELEMENTS means there is nothing to restore in the context, so we do not
    // even try to find an undispatched completion. Only when some values were replaced do we take
    // the slow path of figuring out where/how to restore them.
    val undispatchedCompletion =
        if (oldValue === NO_THREAD_ELEMENTS) null
        else continuation.updateUndispatchedCompletion(context, oldValue)
    try {
        return block()
    } finally {
        // Restore here when no undispatched completion holds the saved state, or when it still
        // held it and we have just cleared it; otherwise skip the restore.
        val restoreHere = undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()
        if (restoreHere) restoreThreadContext(context, oldValue)
    }
}

/**
 * Looks up the [UndispatchedCoroutine] that this continuation completes into (if any) and hands it
 * the given [context] and [oldValue] through [UndispatchedCoroutine.saveThreadContext].
 *
 * Returns that coroutine, or `null` when this continuation is not a [CoroutineStackFrame],
 * when [context] carries no [UndispatchedMarker], or when the stack walk finds nothing.
 */
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
    if (this !is CoroutineStackFrame) return null
    /*
     * Fast-path to detect whether we have an undispatched coroutine at all in our stack.
     *
     * Implementation note.
     * If we ever find that stack-walking for thread-locals is way too slow, here is another idea:
     * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
     * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
     *    from the context when creating dispatched coroutine in `withContext`.
     *    Another option is to "unmark it" instead of removing to save an allocation.
     * Both options should work, but it requires more careful studying of the performance
     * and, mostly, maintainability impact.
     */
    if (context[UndispatchedMarker] === null) return null
    return undispatchedCompletion()?.also { found -> found.saveThreadContext(context, oldValue) }
}

/**
 * Walks up the chain of [CoroutineStackFrame.callerFrame]s and returns the first
 * [UndispatchedCoroutine] encountered. The walk gives up with `null` as soon as the current
 * frame is a [DispatchedCoroutine] or has no caller frame.
 */
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
    // A dispatched coroutine ends the search
    if (this is DispatchedCoroutine<*>) return null
    // Direct completion of this continuation; anything without a caller frame is not supported
    val caller = callerFrame ?: return null
    if (caller is UndispatchedCoroutine<*>) return caller // found UndispatchedCoroutine!
    return caller.undispatchedCompletion() // keep walking up the call stack with a tail call
}

/**
 * Marker indicating that an [UndispatchedCoroutine] exists somewhere up in the stack.
 * Used as a performance optimization to avoid stack walking where it is not necessary.
 *
 * The object serves as both the context element and its own key.
 */
private object UndispatchedMarker : CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
    override val key: CoroutineContext.Key<*> get() = this
}

// Used by withContext when context changes, but dispatcher stays the same.
// The coroutine's own context is tagged with UndispatchedMarker.
internal actual class UndispatchedCoroutine<in T> actual constructor(
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(context + UndispatchedMarker, uCont) {

    // Thread-context state stored by saveThreadContext; both are null while nothing is pending
    private var savedContext: CoroutineContext? = null
    private var savedOldValue: Any? = null

    /** Remembers the [context] / [oldValue] pair that should later be passed to `restoreThreadContext`. */
    fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
        savedContext = context
        savedOldValue = oldValue
    }

    /**
     * Drops the remembered state, if any.
     * Returns `true` when there was something to drop and `false` when nothing was saved.
     */
    fun clearThreadContext(): Boolean {
        val hadSavedContext = savedContext != null
        if (hadSavedContext) {
            savedContext = null
            savedOldValue = null
        }
        return hadSavedContext
    }

    override fun afterResume(state: Any?) {
        // Restore (and forget) whatever thread context was left pending for us
        val pendingContext = savedContext
        if (pendingContext != null) {
            restoreThreadContext(pendingContext, savedOldValue)
            savedContext = null
            savedOldValue = null
        }
        // resume undispatched -- update context but stay on the same dispatcher
        val result = recoverResult(state, uCont)
        withContinuationContext(uCont, null) {
            uCont.resumeWith(result)
        }
    }
}

internal actual val CoroutineContext.coroutineName: String? get() {
if (!DEBUG) return null
val coroutineId = this[CoroutineId] ?: return null
Expand Down
9 changes: 5 additions & 4 deletions kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
Expand Up @@ -7,8 +7,8 @@ package kotlinx.coroutines.internal
import kotlinx.coroutines.*
import kotlin.coroutines.*


private val ZERO = Symbol("ZERO")
@JvmField
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context
private class ThreadState(val context: CoroutineContext, n: Int) {
Expand Down Expand Up @@ -60,12 +60,13 @@ private val restoreState =
internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!

// countOrElement is pre-cached in dispatched continuation
// returns NO_THREAD_ELEMENTS if the context does not have any ThreadContextElements
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
@Suppress("NAME_SHADOWING")
val countOrElement = countOrElement ?: threadContextElements(context)
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
return when {
countOrElement === 0 -> ZERO // very fast path when there are no active ThreadContextElements
countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
// ^^^ identity comparison for speed, we know zero always has the same identity
countOrElement is Int -> {
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
Expand All @@ -82,7 +83,7 @@ internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?

internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
when {
oldState === ZERO -> return // very fast path when there are no ThreadContextElements
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
oldState is ThreadState -> {
// slow path with multiple stored ThreadContextElements
oldState.start()
Expand Down

0 comments on commit 5dc55a6

Please sign in to comment.