From 5d0298fd5faaf1f1a10c1c4929bb032ba487c26d Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Fri, 27 Sep 2019 12:14:40 +0300 Subject: [PATCH] Restore thread context elements when directly resuming to parent This fix solves the problem of restoring thread-context when returning to another context in undispatched way. It impacts suspend/resume performance of coroutines that use ThreadContextElement since we have to walk up the coroutine completion stack in search for parent UndispatchedCoroutine. However, there is a fast-path to ensure that there is no performance impact in cases when ThreadContextElement is not used by a coroutine. Fixes #985 --- .../common/src/Builders.common.kt | 14 +- .../common/src/CoroutineContext.common.kt | 1 + .../src/internal/DispatchedContinuation.kt | 4 +- .../common/src/internal/DispatchedTask.kt | 6 +- .../js/src/CoroutineContext.kt | 9 + .../jvm/src/CoroutineContext.kt | 70 ++++++- .../jvm/src/internal/ThreadContext.kt | 9 +- .../test/ThreadContextElementRestoreTest.kt | 198 ++++++++++++++++++ .../native/src/CoroutineContext.kt | 9 + 9 files changed, 299 insertions(+), 21 deletions(-) create mode 100644 kotlinx-coroutines-core/jvm/test/ThreadContextElementRestoreTest.kt diff --git a/kotlinx-coroutines-core/common/src/Builders.common.kt b/kotlinx-coroutines-core/common/src/Builders.common.kt index 6ef1a8daea..25ade3527f 100644 --- a/kotlinx-coroutines-core/common/src/Builders.common.kt +++ b/kotlinx-coroutines-core/common/src/Builders.common.kt @@ -207,25 +207,17 @@ private class LazyStandaloneCoroutine( } // Used by withContext when context changes, but dispatcher stays the same -private class UndispatchedCoroutine( +internal expect class UndispatchedCoroutine( context: CoroutineContext, uCont: Continuation -) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) { - // resume undispatched -- update context by stay on the same dispatcher - val result = recoverResult(state, uCont) - withCoroutineContext(uCont.context, null) { - uCont.resumeWith(result) - } - } -} +) : ScopeCoroutine private const val UNDECIDED = 0 private const val SUSPENDED = 1 private const val RESUMED = 2 // Used by withContext when context dispatcher changes -private class DispatchedCoroutine( +internal class DispatchedCoroutine( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt index 51374603c3..17ad66c19a 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt @@ -19,5 +19,6 @@ internal expect val DefaultDelay: Delay // countOrElement -- pre-cached value for ThreadContext.kt internal expect inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T +internal expect inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T internal expect fun Continuation<*>.toDebugString(): String internal expect val CoroutineContext.coroutineName: String? diff --git a/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt b/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt index b7b2954f6a..20b77bfe2c 100644 --- a/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt +++ b/kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt @@ -23,7 +23,7 @@ internal class DispatchedContinuation( @JvmField @Suppress("PropertyName") internal var _state: Any? = UNDEFINED - override val callerFrame: CoroutineStackFrame? = continuation as? CoroutineStackFrame + override val callerFrame: CoroutineStackFrame? get() = continuation as? CoroutineStackFrame override fun getStackTraceElement(): StackTraceElement? = null @JvmField // pre-cached value to avoid ctx.fold on every resumption internal val countOrElement = threadContextElements(context) @@ -235,7 +235,7 @@ internal class DispatchedContinuation( @Suppress("NOTHING_TO_INLINE") // we need it inline to save us an entry on the stack inline fun resumeUndispatchedWith(result: Result) { - withCoroutineContext(context, countOrElement) { + withContinuationContext(continuation, countOrElement) { continuation.resumeWith(result) } } diff --git a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt index caf87f143e..ce05979db6 100644 --- a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt +++ b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt @@ -85,9 +85,9 @@ internal abstract class DispatchedTask( try { val delegate = delegate as DispatchedContinuation val continuation = delegate.continuation - val context = continuation.context - val state = takeState() // NOTE: Must take state in any case, even if cancelled - withCoroutineContext(context, delegate.countOrElement) { + withContinuationContext(continuation, delegate.countOrElement) { + val context = continuation.context + val state = takeState() // NOTE: Must take state in any case, even if cancelled val exception = getExceptionalResult(state) /* * Check whether continuation was originally resumed with an exception. diff --git a/kotlinx-coroutines-core/js/src/CoroutineContext.kt b/kotlinx-coroutines-core/js/src/CoroutineContext.kt index c0b0c511f9..aed0327700 100644 --- a/kotlinx-coroutines-core/js/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/js/src/CoroutineContext.kt @@ -4,6 +4,7 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* import kotlin.browser.* import kotlin.coroutines.* @@ -49,5 +50,13 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): // No debugging facilities on JS internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on JS + +internal actual class UndispatchedCoroutine actual constructor( + context: CoroutineContext, + uCont: Continuation +) : ScopeCoroutine(context, uCont) { + override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) +} diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index 5a69d48aac..3796c91c11 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -2,12 +2,14 @@ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ +@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER") + package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlinx.coroutines.scheduling.* -import java.util.concurrent.atomic.* import kotlin.coroutines.* +import kotlin.coroutines.jvm.internal.* internal const val COROUTINES_SCHEDULER_PROPERTY_NAME = "kotlinx.coroutines.scheduler" @@ -48,6 +50,72 @@ internal actual inline fun withCoroutineContext(context: CoroutineContext, c } } +/** + * Executes a block using a context of a given continuation. + */ +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.undispatchedCompletion() + } else + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + undispatchedCompletion?.saveThreadContext(context, oldValue) + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) + restoreThreadContext(context, oldValue) + } +} + +internal tailrec fun Continuation<*>.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: Continuation<*> = when (this) { + is BaseContinuationImpl -> completion ?: return null // regular suspending function -- direct resume + is DispatchedCoroutine -> return null // dispatches on resume + is ScopeCoroutine -> uCont // other scoped coroutine -- direct resume + else -> return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +// Used by withContext when context changes, but dispatcher stays the same +internal actual class UndispatchedCoroutine actual constructor( + context: CoroutineContext, + uCont: Continuation +) : ScopeCoroutine(context, uCont) { + private var savedContext: CoroutineContext? = null + private var savedOldValue: Any? = null + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + savedContext = context + savedOldValue = oldValue + } + + fun clearThreadContext(): Boolean { + if (savedContext == null) return false + savedContext = null + savedOldValue = null + return true + } + + override fun afterResume(state: Any?) { + savedContext?.let { context -> + restoreThreadContext(context, savedOldValue) + savedContext = null + savedOldValue = null + } + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } +} + internal actual val CoroutineContext.coroutineName: String? get() { if (!DEBUG) return null val coroutineId = this[CoroutineId] ?: return null diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 9d9d30e41d..18c2ce0459 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -7,8 +7,8 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* - -private val ZERO = Symbol("ZERO") +@JvmField +internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") // Used when there are >= 2 active elements in the context private class ThreadState(val context: CoroutineContext, n: Int) { @@ -60,12 +60,13 @@ private val restoreState = internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! // countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { @Suppress("NAME_SHADOWING") val countOrElement = countOrElement ?: threadContextElements(context) @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") return when { - countOrElement === 0 -> ZERO // very fast path when there are no active ThreadContextElements + countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements // ^^^ identity comparison for speed, we know zero always has the same identity countOrElement is Int -> { // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values @@ -82,7 +83,7 @@ internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any? internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { when { - oldState === ZERO -> return // very fast path when there are no ThreadContextElements + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements oldState is ThreadState -> { // slow path with multiple stored ThreadContextElements oldState.start() diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementRestoreTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementRestoreTest.kt new file mode 100644 index 0000000000..560aa4c47d --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementRestoreTest.kt @@ -0,0 +1,198 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import org.junit.Test +import kotlin.coroutines.* +import kotlin.test.* + +class ThreadContextElementRestoreTest : TestBase() { + private val tl = ThreadLocal() + + // Checks that ThreadLocal context is properly restored after executing the given block inside + // withContext(tl.asContextElement("OK")) code running in different outer contexts + private inline fun check(crossinline block: suspend () -> Unit) = runTest { + val mainDispatcher = coroutineContext[ContinuationInterceptor] as CoroutineDispatcher + // Scenario #1: withContext(ThreadLocal) direct from runTest + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + // Scenario #2: withContext(ThreadLocal) from coroutineScope + coroutineScope { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #3: withContext(ThreadLocal) from undispatched withContext + withContext(CoroutineName("NAME")) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #4: withContext(ThreadLocal) from dispatched withContext + withContext(wrapperDispatcher()) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #5: withContext(ThreadLocal) from withContext(ThreadLocal) + withContext(tl.asContextElement(null)) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #6: withContext(ThreadLocal) from withTimeout + withTimeout(1000) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #7: withContext(ThreadLocal) from withContext(Unconfined) + withContext(Dispatchers.Unconfined) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #8: withContext(ThreadLocal) from withContext(Default) + withContext(Dispatchers.Default) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + // Scenario #9: withContext(ThreadLocal) from withContext(mainDispatcher) + withContext(mainDispatcher) { + withContext(tl.asContextElement("OK")) { + block() + assertEquals("OK", tl.get()) + } + assertEquals(null, tl.get()) + } + } + + @Test + fun testSimpleNoSuspend() = + check {} + + @Test + fun testSimpleDelay() = check { + delay(1) + } + + @Test + fun testSimpleYield() = check { + yield() + } + + private suspend fun deepDelay() { + deepDelay2(); deepDelay2() + } + + private suspend fun deepDelay2() { + delay(1); delay(1) + } + + @Test + fun testDeepDelay() = check { + deepDelay() + } + + private suspend fun deepYield() { + deepYield2(); deepYield2() + } + + private suspend fun deepYield2() { + yield(); yield() + } + + @Test + fun testDeepYield() = check { + deepYield() + } + + @Test + fun testCoroutineScopeDelay() = check { + coroutineScope { + delay(1) + } + } + + @Test + fun testCoroutineScopeYield() = check { + coroutineScope { + yield() + } + } + + @Test + fun testWithContextUndispatchedDelay() = check { + withContext(CoroutineName("INNER")) { + delay(1) + } + } + + @Test + fun testWithContextUndispatchedYield() = check { + withContext(CoroutineName("INNER")) { + yield() + } + } + + @Test + fun testWithContextDispatchedDelay() = check { + withContext(wrapperDispatcher()) { + delay(1) + } + } + + @Test + fun testWithContextDispatchedYield() = check { + withContext(wrapperDispatcher()) { + yield() + } + } + + @Test + fun testWithTimeoutDelay() = check { + withTimeout(1000) { + delay(1) + } + } + + @Test + fun testWithTimeoutYield() = check { + withTimeout(1000) { + yield() + } + } + + @Test + fun testWithUnconfinedContextDelay() = check { + withContext(Dispatchers.Unconfined) { + delay(1) + } + } + @Test + fun testWithUnconfinedContextYield() = check { + withContext(Dispatchers.Unconfined) { + yield() + } + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index 4ec1289ee7..86ffa8dec9 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -4,6 +4,7 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.native.concurrent.* @@ -38,5 +39,13 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): // No debugging facilities on native internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on native + +internal actual class UndispatchedCoroutine actual constructor( + context: CoroutineContext, + uCont: Continuation +) : ScopeCoroutine(context, uCont) { + override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) +} \ No newline at end of file