From 603bd79c96d2473e54cb4a9043b30102554dd7e9 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Wed, 17 Nov 2021 11:59:17 +0300 Subject: [PATCH] Implemented `CopyableThreadContextElement` with a `copyForChildCoroutine()`. (#3025) * This is a `ThreadContextElement` that is copy-constructed when a new coroutine is created and inherits the context. Co-authored-by: Tyson Henning Fixes #2839 --- .../api/kotlinx-coroutines-core.api | 11 ++ .../jvm/src/CoroutineContext.kt | 21 ++- .../jvm/src/ThreadContextElement.kt | 63 +++++++++ .../jvm/test/ThreadContextElementTest.kt | 130 +++++++++++++++++- 4 files changed, 223 insertions(+), 2 deletions(-) diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 9ccef08308..36a516e070 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -140,6 +140,17 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run public fun (Ljava/lang/String;Ljava/lang/Throwable;)V } +public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement { + public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement; +} + +public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls { + public static fun fold (Lkotlinx/coroutines/CopyableThreadContextElement;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; + public static fun get (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element; + public static fun minusKey (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext; + public static fun plus (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; +} + public abstract interface class kotlinx/coroutines/CopyableThrowable { public abstract fun createCopy ()Ljava/lang/Throwable; } diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index 9a8f168bcd..d562207f8b 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -16,12 +16,31 @@ import kotlin.coroutines.jvm.internal.CoroutineStackFrame */ @ExperimentalCoroutinesApi public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = coroutineContext + context + val combined = coroutineContext.foldCopiesForChildCoroutine() + context val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) debug + Dispatchers.Default else debug } +/** + * Returns the [CoroutineContext] for a child coroutine to inherit. + * + * If any [CopyableThreadContextElement] is in the [this], calls + * [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext] + * by folding the returned copied elements into [this]. + * + * Returns [this] if `this` has zero [CopyableThreadContextElement] in it. + */ +private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext { + val hasToCopy = fold(false) { result, it -> + result || it is CopyableThreadContextElement<*> + } + if (!hasToCopy) return this + return fold(EmptyCoroutineContext) { combined, it -> + combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it + } +} + /** * Executes a block using a given coroutine context. */ diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index 37fd70a23e..1b825cef01 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -77,6 +77,69 @@ public interface ThreadContextElement : CoroutineContext.Element { public fun restoreThreadContext(context: CoroutineContext, oldState: S) } +/** + * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. + * + * When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement] + * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. + * + * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine + * will be visible to _itself_ and any child coroutine launched _after_ that write. + * + * Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen + * to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_ + * launching a child coroutine will not be visible to that child coroutine. + * + * This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and + * correctly, regardless of the coroutine's structured concurrency. + * + * This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace + * is in a coroutine: + * + * ``` + * class TraceContextElement(val traceData: TraceData?) : CopyableThreadContextElement { + * companion object Key : CoroutineContext.Key + * override val key: CoroutineContext.Key + * get() = Key + * + * override fun updateThreadContext(context: CoroutineContext): TraceData? { + * val oldState = traceThreadLocal.get() + * traceThreadLocal.set(data) + * return oldState + * } + * + * override fun restoreThreadContext(context: CoroutineContext, oldData: TraceData?) { + * traceThreadLocal.set(oldState) + * } + * + * override fun copyForChildCoroutine(): CopyableThreadContextElement { + * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes + * // ThreadLocal writes between resumption of the parent coroutine and the launch of the + * // child coroutine visible to the child. + * return CopyForChildCoroutineElement(traceThreadLocal.get()) + * } + * } + * ``` + * + * A coroutine using this mechanism can safely call Java code that assumes it's called using a + * `Thread`. + */ +@ExperimentalCoroutinesApi +public interface CopyableThreadContextElement : ThreadContextElement { + + /** + * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * coroutine's context that is under construction. + * + * This function is called on the element each time a new coroutine inherits a context containing it, + * and the returned value is folded into the context given to the child. + * + * Since this method is called whenever a new coroutine is launched in a context containing this + * [CopyableThreadContextElement], implementations are performance-sensitive. + */ + public fun copyForChildCoroutine(): CopyableThreadContextElement +} + /** * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement] * maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on. diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index ea43c7ade2..baba4aa8e6 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -54,7 +54,6 @@ class ThreadContextElementTest : TestBase() { assertNull(myThreadLocal.get()) } - @Test fun testWithContext() = runTest { expect(1) @@ -86,6 +85,78 @@ class ThreadContextElementTest : TestBase() { finish(7) } + + @Test + fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest { + var parentElement: MyElement? = null + var inheritedElement: MyElement? = null + + newSingleThreadContext("withContext").use { + withContext(it + MyElement(MyData())) { + parentElement = coroutineContext[MyElement.Key] + launch { + inheritedElement = coroutineContext[MyElement.Key] + } + } + } + + assertSame(inheritedElement, parentElement, + "Inner and outer coroutines did not have the same object reference to a" + + " ThreadContextElement that did not override `copyForChildCoroutine()`") + } + + @Test + fun testCopyableElementCopiedOnLaunch() = runTest { + var parentElement: CopyForChildCoroutineElement? = null + var inheritedElement: CopyForChildCoroutineElement? = null + + newSingleThreadContext("withContext").use { + withContext(it + CopyForChildCoroutineElement(MyData())) { + parentElement = coroutineContext[CopyForChildCoroutineElement.Key] + launch { + inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key] + } + } + } + + assertNotSame(inheritedElement, parentElement, + "Inner coroutine did not copy its copyable ThreadContextElement.") + } + + @Test + fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest { + newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { + val startData = MyData() + withContext(it + CopyForChildCoroutineElement(startData)) { + val forBlockData = MyData() + myThreadLocal.setForBlock(forBlockData) { + assertSame(myThreadLocal.get(), forBlockData) + launch { + assertSame(myThreadLocal.get(), forBlockData) + } + launch { + assertSame(myThreadLocal.get(), forBlockData) + // Modify value in child coroutine. Writes to the ThreadLocal and + // the (copied) ThreadLocalElement's memory are not visible to peer or + // ancestor coroutines, so this write is both threadsafe and coroutinesafe. + val innerCoroutineData = MyData() + myThreadLocal.setForBlock(innerCoroutineData) { + assertSame(myThreadLocal.get(), innerCoroutineData) + } + assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored. + } + launch { + val innerCoroutineData = MyData() + myThreadLocal.setForBlock(innerCoroutineData) { + assertSame(myThreadLocal.get(), innerCoroutineData) + } + assertSame(myThreadLocal.get(), forBlockData) + } + } + assertSame(myThreadLocal.get(), startData) // Asserts value was restored. + } + } + } } class MyData @@ -114,3 +185,60 @@ class MyElement(val data: MyData) : ThreadContextElement { myThreadLocal.set(oldState) } } + +/** + * A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine]. + */ +class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key + get() = Key + + override fun updateThreadContext(context: CoroutineContext): MyData? { + val oldState = myThreadLocal.get() + myThreadLocal.set(data) + return oldState + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + myThreadLocal.set(oldState) + } + + /** + * At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new + * child coroutine, and that value is copied to a new, unique, ThreadContextElement memory + * reference for the child coroutine to use uniquely. + * + * n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not + * the value initially passed to the ThreadContextElement in order to reflect writes made to the + * ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes + * will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the + * thread and calls [restoreThreadContext]. + */ + override fun copyForChildCoroutine(): CopyableThreadContextElement { + return CopyForChildCoroutineElement(myThreadLocal.get()) + } +} + +/** + * Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block]. + * + * When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a + * [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically + * at every statement reached, whether that statement is reached immediately, across suspend and + * redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal` + * by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal` + * by the parent coroutine _after_ launching a child coroutine will not be visible to that child + * coroutine. + */ +private inline fun ThreadLocal.setForBlock( + value: ThreadLocalT, + crossinline block: () -> OutputT +) { + val priorValue = get() + set(value) + block() + set(priorValue) +} +