From ac276a1d33aa041067fb51102d70f45615d52610 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Tue, 1 Feb 2022 12:29:11 +0300 Subject: [PATCH] =?UTF-8?q?Confine=20context-specific=20state=20to=20the?= =?UTF-8?q?=20thread=20in=20UndispatchedCoroutine=E2=80=A6=20(#3155)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Confine context-specific state to the thread in UndispatchedCoroutine in order to avoid state interference when the coroutine is updated concurrently. Concurrency is inevitable in this scenario: when the coroutine that has UndispatchedCoroutine as its completion suspends, we have to clear the thread context, but while we are doing so, concurrent resume of the coroutine could've happened that also ends up in save/clear/update context Fixes #2930 --- .../jvm/src/CoroutineContext.kt | 32 +++++---- .../jvm/test/ThreadLocalStressTest.kt | 72 +++++++++++++++++++ 2 files changed, 92 insertions(+), 12 deletions(-) create mode 100644 kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index d562207f8b..6291ea2b97 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -107,7 +107,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC /** * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. - * Used as a performance optimization to avoid stack walking where it is not nesessary. + * Used as a performance optimization to avoid stack walking where it is not necessary. */ private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { override val key: CoroutineContext.Key<*> @@ -120,26 +120,34 @@ internal actual class UndispatchedCoroutineactual constructor ( uCont: Continuation ) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { - private var savedContext: CoroutineContext? = null - private var savedOldValue: Any? = null + /* + * The state is thread-local because this coroutine can be used concurrently. + * Scenario of usage (withContinuationContext): + * val state = saveThreadContext(ctx) + * try { + * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called + * // COROUTINE_SUSPENDED is returned + * } finally { + * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread + * // and it also calls saveThreadContext and clearThreadContext + * } + */ + private var threadStateToRecover = ThreadLocal>() fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { - savedContext = context - savedOldValue = oldValue + threadStateToRecover.set(context to oldValue) } fun clearThreadContext(): Boolean { - if (savedContext == null) return false - savedContext = null - savedOldValue = null + if (threadStateToRecover.get() == null) return false + threadStateToRecover.set(null) return true } override fun afterResume(state: Any?) { - savedContext?.let { context -> - restoreThreadContext(context, savedOldValue) - savedContext = null - savedOldValue = null + threadStateToRecover.get()?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + threadStateToRecover.set(null) } // resume undispatched -- update context but stay on the same dispatcher val result = recoverResult(state, uCont) diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt new file mode 100644 index 0000000000..f9941d0215 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.test.* + + +class ThreadLocalStressTest : TestBase() { + + private val threadLocal = ThreadLocal() + + // See the comment in doStress for the machinery + @Test + fun testStress() = runTest { + repeat (100 * stressTestMultiplierSqrt) { + withContext(Dispatchers.Default) { + repeat(100) { + launch { + doStress(null) + } + } + } + } + } + + @Test + fun testStressWithOuterValue() = runTest { + repeat (100 * stressTestMultiplierSqrt) { + withContext(Dispatchers.Default + threadLocal.asContextElement("bar")) { + repeat(100) { + launch { + doStress("bar") + } + } + } + } + } + + private suspend fun doStress(expectedValue: String?) { + assertEquals(expectedValue, threadLocal.get()) + try { + /* + * Here we are using very specific code-path to trigger the execution we want to. + * The bug, in general, has a larger impact, but this particular code pinpoints it: + * + * 1) We use _undispatched_ withContext with thread element + * 2) We cancel the coroutine + * 3) We use 'suspendCancellableCoroutineReusable' that does _postponed_ cancellation check + * which makes the reproduction of this race pretty reliable. + * + * Now the following code path is likely to be triggered: + * + * T1 from within 'withContinuationContext' method: + * Finds 'oldValue', finds undispatched completion, invokes its 'block' argument. + * 'block' is this coroutine, it goes to 'trySuspend', checks for postponed cancellation and *dispatches* it. + * The execution stops _right_ before 'undispatchedCompletion.clearThreadContext()'. + * + * T2 now executes the dispatched cancellation and concurrently mutates the state of the undispatched completion. + * All bets are off, now both threads can leave the thread locals state inconsistent. + */ + withContext(threadLocal.asContextElement("foo")) { + yield() + cancel() + suspendCancellableCoroutineReusable { } + } + } finally { + assertEquals(expectedValue, threadLocal.get()) + } + } +}