diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index d562207f8b..6291ea2b97 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -107,7 +107,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC /** * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. - * Used as a performance optimization to avoid stack walking where it is not nesessary. + * Used as a performance optimization to avoid stack walking where it is not necessary. */ private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { override val key: CoroutineContext.Key<*> @@ -120,26 +120,34 @@ internal actual class UndispatchedCoroutineactual constructor ( uCont: Continuation ) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { - private var savedContext: CoroutineContext? = null - private var savedOldValue: Any? = null + /* + * The state is thread-local because this coroutine can be used concurrently. + * Scenario of usage (withContinuationContext): + * val state = saveThreadContext(ctx) + * try { + * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called + * // COROUTINE_SUSPENDED is returned + * } finally { + * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread + * // and it also calls saveThreadContext and clearThreadContext + * } + */ + private var threadStateToRecover = ThreadLocal>() fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { - savedContext = context - savedOldValue = oldValue + threadStateToRecover.set(context to oldValue) } fun clearThreadContext(): Boolean { - if (savedContext == null) return false - savedContext = null - savedOldValue = null + if (threadStateToRecover.get() == null) return false + threadStateToRecover.set(null) return true } override fun afterResume(state: Any?) { - savedContext?.let { context -> - restoreThreadContext(context, savedOldValue) - savedContext = null - savedOldValue = null + threadStateToRecover.get()?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + threadStateToRecover.set(null) } // resume undispatched -- update context but stay on the same dispatcher val result = recoverResult(state, uCont) diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt new file mode 100644 index 0000000000..f9941d0215 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.test.* + + +class ThreadLocalStressTest : TestBase() { + + private val threadLocal = ThreadLocal() + + // See the comment in doStress for the machinery + @Test + fun testStress() = runTest { + repeat (100 * stressTestMultiplierSqrt) { + withContext(Dispatchers.Default) { + repeat(100) { + launch { + doStress(null) + } + } + } + } + } + + @Test + fun testStressWithOuterValue() = runTest { + repeat (100 * stressTestMultiplierSqrt) { + withContext(Dispatchers.Default + threadLocal.asContextElement("bar")) { + repeat(100) { + launch { + doStress("bar") + } + } + } + } + } + + private suspend fun doStress(expectedValue: String?) { + assertEquals(expectedValue, threadLocal.get()) + try { + /* + * Here we are using very specific code-path to trigger the execution we want to. + * The bug, in general, has a larger impact, but this particular code pinpoints it: + * + * 1) We use _undispatched_ withContext with thread element + * 2) We cancel the coroutine + * 3) We use 'suspendCancellableCoroutineReusable' that does _postponed_ cancellation check + * which makes the reproduction of this race pretty reliable. + * + * Now the following code path is likely to be triggered: + * + * T1 from within 'withContinuationContext' method: + * Finds 'oldValue', finds undispatched completion, invokes its 'block' argument. + * 'block' is this coroutine, it goes to 'trySuspend', checks for postponed cancellation and *dispatches* it. + * The execution stops _right_ before 'undispatchedCompletion.clearThreadContext()'. + * + * T2 now executes the dispatched cancellation and concurrently mutates the state of the undispatched completion. + * All bets are off, now both threads can leave the thread locals state inconsistent. + */ + withContext(threadLocal.asContextElement("foo")) { + yield() + cancel() + suspendCancellableCoroutineReusable { } + } + } finally { + assertEquals(expectedValue, threadLocal.get()) + } + } +}