From a5dd74b2b325113586768b8b61af6e2833a139c2 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Mon, 4 Apr 2022 03:57:56 -0700 Subject: [PATCH] CopyableThreadContextElement implementation (#3227) New approach eagerly copies corresponding elements to avoid accidental top-level reuse and also provides merge capability in case when an element is being overwritten. Merge capability is crucial in tracing scenarios to properly preserve the state of linked thread locals Co-authored-by: dkhalanskyjb <52952525+dkhalanskyjb@users.noreply.github.com> --- .../api/kotlinx-coroutines-core.api | 4 +- .../common/src/Builders.common.kt | 3 +- .../common/src/CoroutineContext.common.kt | 12 +- .../js/src/CoroutineContext.kt | 4 + .../jvm/src/CoroutineContext.kt | 87 +++++++++--- .../jvm/src/ThreadContextElement.kt | 33 ++++- .../jvm/test/ThreadContextElementTest.kt | 13 +- .../test/ThreadContextMutableCopiesTest.kt | 134 ++++++++++++++++++ .../native/src/CoroutineContext.kt | 4 + 9 files changed, 259 insertions(+), 35 deletions(-) create mode 100644 kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index d1fc624a5e..79f3cf4308 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -141,7 +141,8 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run } public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement { - public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement; + public abstract fun copyForChild ()Lkotlinx/coroutines/CopyableThreadContextElement; + public abstract fun mergeForChild (Lkotlin/coroutines/CoroutineContext$Element;)Lkotlin/coroutines/CoroutineContext; } public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls { @@ -156,6 +157,7 @@ public abstract interface class kotlinx/coroutines/CopyableThrowable { } public final class kotlinx/coroutines/CoroutineContextKt { + public static final fun newCoroutineContext (Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; public static final fun newCoroutineContext (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/kotlinx-coroutines-core/common/src/Builders.common.kt b/kotlinx-coroutines-core/common/src/Builders.common.kt index a11ffe9eb4..c360724245 100644 --- a/kotlinx-coroutines-core/common/src/Builders.common.kt +++ b/kotlinx-coroutines-core/common/src/Builders.common.kt @@ -148,7 +148,8 @@ public suspend fun withContext( return suspendCoroutineUninterceptedOrReturn sc@ { uCont -> // compute new context val oldContext = uCont.context - val newContext = oldContext + context + // Copy CopyableThreadContextElement if necessary + val newContext = oldContext.newCoroutineContext(context) // always check for cancellation of new context newContext.ensureActive() // FAST PATH #1 -- new context is the same as the old one diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt index da094e152d..9153f39821 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt @@ -7,11 +7,19 @@ package kotlinx.coroutines import kotlin.coroutines.* /** - * Creates a context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher or - * [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on). + * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or + * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. */ public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext +/** + * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. + * @suppress + */ +@InternalCoroutinesApi +public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext + @PublishedApi @Suppress("PropertyName") internal expect val DefaultDelay: Delay diff --git a/kotlinx-coroutines-core/js/src/CoroutineContext.kt b/kotlinx-coroutines-core/js/src/CoroutineContext.kt index 95cb3c2964..8036c88a10 100644 --- a/kotlinx-coroutines-core/js/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/js/src/CoroutineContext.kt @@ -42,6 +42,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): combined + Dispatchers.Default else combined } +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + return this + addedContext +} + // No debugging facilities on JS internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index 6291ea2b97..e08b805295 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -9,36 +9,83 @@ import kotlin.coroutines.* import kotlin.coroutines.jvm.internal.CoroutineStackFrame /** - * Creates context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher nor - * [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on). - * + * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or + * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. */ @ExperimentalCoroutinesApi public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = coroutineContext.foldCopiesForChildCoroutine() + context + val combined = foldCopies(coroutineContext, context, true) val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) debug + Dispatchers.Default else debug } /** - * Returns the [CoroutineContext] for a child coroutine to inherit. - * - * If any [CopyableThreadContextElement] is in the [this], calls - * [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext] - * by folding the returned copied elements into [this]. - * - * Returns [this] if `this` has zero [CopyableThreadContextElement] in it. + * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. + * @suppress */ -private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext { - val hasToCopy = fold(false) { result, it -> - result || it is CopyableThreadContextElement<*> +@InternalCoroutinesApi +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + /* + * Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements) + * contains copyable elements. + */ + if (!addedContext.hasCopyableElements()) return this + addedContext + return foldCopies(this, addedContext, false) +} + +private fun CoroutineContext.hasCopyableElements(): Boolean = + fold(false) { result, it -> result || it is CopyableThreadContextElement<*> } + +/** + * Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary. + * The rules are the following: + * * If neither context has CTCE, the sum of two contexts is returned + * * Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context + * is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`. + * * Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild] + * * Every CTCE from the right-hand side context that hasn't been merged is copied + * * Everything else is added to the resulting context as is. + */ +private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext { + // Do we have something to copy left-hand side? + val hasElementsLeft = originalContext.hasCopyableElements() + val hasElementsRight = appendContext.hasCopyableElements() + + // Nothing to fold, so just return the sum of contexts + if (!hasElementsLeft && !hasElementsRight) { + return originalContext + appendContext + } + + var leftoverContext = appendContext + val folded = originalContext.fold(EmptyCoroutineContext) { result, element -> + if (element !is CopyableThreadContextElement<*>) return@fold result + element + // Will this element be overwritten? + val newElement = leftoverContext[element.key] + // No, just copy it + if (newElement == null) { + // For 'withContext'-like builders we do not copy as the element is not shared + return@fold result + if (isNewCoroutine) element.copyForChild() else element + } + // Yes, then first remove the element from append context + leftoverContext = leftoverContext.minusKey(element.key) + // Return the sum + @Suppress("UNCHECKED_CAST") + return@fold result + (element as CopyableThreadContextElement).mergeForChild(newElement) } - if (!hasToCopy) return this - return fold(EmptyCoroutineContext) { combined, it -> - combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it + + if (hasElementsRight) { + leftoverContext = leftoverContext.fold(EmptyCoroutineContext) { result, element -> + // We're appending new context element -- we have to copy it, otherwise it may be shared with others + if (element is CopyableThreadContextElement<*>) { + return@fold result + element.copyForChild() + } + return@fold result + element + } } + return folded + leftoverContext } /** @@ -77,7 +124,7 @@ internal actual inline fun withContinuationContext(continuation: Continuatio internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { if (this !is CoroutineStackFrame) return null /* - * Fast-path to detect whether we have unispatched coroutine at all in our stack. + * Fast-path to detect whether we have undispatched coroutine at all in our stack. * * Implementation note. * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: @@ -88,8 +135,8 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont * Both options should work, but it requires more careful studying of the performance * and, mostly, maintainability impact. */ - val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker] !== null - if (!potentiallyHasUndispatchedCorotuine) return null + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null val completion = undispatchedCompletion() completion?.saveThreadContext(context, oldValue) return completion diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index 1a960699c7..d2b6b6b988 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -80,7 +80,7 @@ public interface ThreadContextElement : CoroutineContext.Element { /** * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. * - * When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement] + * When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement] * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. * * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine @@ -99,6 +99,7 @@ public interface ThreadContextElement : CoroutineContext.Element { * ``` * class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement { * companion object Key : CoroutineContext.Key + * * override val key: CoroutineContext.Key = Key * * override fun updateThreadContext(context: CoroutineContext): TraceData? { @@ -111,24 +112,35 @@ public interface ThreadContextElement : CoroutineContext.Element { * traceThreadLocal.set(oldState) * } * - * override fun copyForChildCoroutine(): CopyableThreadContextElement { + * override fun copyForChild(): TraceContextElement { * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes * // ThreadLocal writes between resumption of the parent coroutine and the launch of the * // child coroutine visible to the child. * return TraceContextElement(traceThreadLocal.get()?.copy()) * } + * + * override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext { + * // Merge operation defines how to handle situations when both + * // the parent coroutine has an element in the context and + * // an element with the same key was also + * // explicitly passed to the child coroutine. + * // If merging does not require special behavior, + * // the copy of the element can be returned. + * return TraceContextElement(traceThreadLocal.get()?.copy()) + * } * } * ``` * - * A coroutine using this mechanism can safely call Java code that assumes it's called using a - * `Thread`. + * A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's + * value is installed into the target thread local. */ +@DelicateCoroutinesApi @ExperimentalCoroutinesApi public interface CopyableThreadContextElement : ThreadContextElement { /** * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child - * coroutine's context that is under construction. + * coroutine's context that is under construction if the added context does not contain an element with the same [key]. * * This function is called on the element each time a new coroutine inherits a context containing it, * and the returned value is folded into the context given to the child. @@ -136,7 +148,16 @@ public interface CopyableThreadContextElement : ThreadContextElement { * Since this method is called whenever a new coroutine is launched in a context containing this * [CopyableThreadContextElement], implementations are performance-sensitive. */ - public fun copyForChildCoroutine(): CopyableThreadContextElement + public fun copyForChild(): CopyableThreadContextElement + + /** + * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * coroutine's context that is under construction if the added context does contain an element with the same [key]. + * + * This method is invoked on the original element, accepting as the parameter + * the element that is supposed to overwrite it. + */ + public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext } /** diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index baba4aa8e6..ec45406bce 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -126,8 +126,7 @@ class ThreadContextElementTest : TestBase() { @Test fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest { newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { - val startData = MyData() - withContext(it + CopyForChildCoroutineElement(startData)) { + withContext(it + CopyForChildCoroutineElement(MyData())) { val forBlockData = MyData() myThreadLocal.setForBlock(forBlockData) { assertSame(myThreadLocal.get(), forBlockData) @@ -153,7 +152,7 @@ class ThreadContextElementTest : TestBase() { assertSame(myThreadLocal.get(), forBlockData) } } - assertSame(myThreadLocal.get(), startData) // Asserts value was restored. + assertNull(myThreadLocal.get()) // Asserts value was restored to its origin } } } @@ -187,7 +186,7 @@ class MyElement(val data: MyData) : ThreadContextElement { } /** - * A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine]. + * A [ThreadContextElement] that implements copy semantics in [copyForChild]. */ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { companion object Key : CoroutineContext.Key @@ -201,6 +200,10 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle return oldState } + override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement { + TODO("Not used in tests") + } + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { myThreadLocal.set(oldState) } @@ -216,7 +219,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle * will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the * thread and calls [restoreThreadContext]. */ - override fun copyForChildCoroutine(): CopyableThreadContextElement { + override fun copyForChild(): CopyForChildCoroutineElement { return CopyForChildCoroutineElement(myThreadLocal.get()) } } diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt new file mode 100644 index 0000000000..34e5955fd7 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt @@ -0,0 +1,134 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.coroutines.* +import kotlin.test.* + +class ThreadContextMutableCopiesTest : TestBase() { + companion object { + val threadLocalData: ThreadLocal> = ThreadLocal.withInitial { ArrayList() } + } + + class MyMutableElement( + val mutableData: MutableList + ) : CopyableThreadContextElement> { + + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key<*> + get() = Key + + override fun updateThreadContext(context: CoroutineContext): MutableList { + val st = threadLocalData.get() + threadLocalData.set(mutableData) + return st + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: MutableList) { + threadLocalData.set(oldState) + } + + override fun copyForChild(): MyMutableElement { + return MyMutableElement(ArrayList(mutableData)) + } + + override fun mergeForChild(overwritingElement: CoroutineContext.Element): MyMutableElement { + overwritingElement as MyMutableElement // <- app-specific, may be another subtype + return MyMutableElement((mutableData.toSet() + overwritingElement.mutableData).toMutableList()) + } + } + + @Test + fun testDataIsCopied() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + val data = threadLocalData.get() + expect(1) + launch(root) { + assertNotSame(data, threadLocalData.get()) + assertEquals(data, threadLocalData.get()) + finish(2) + } + } + } + + @Test + fun testDataIsNotOverwritten() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + expect(1) + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + launch(Dispatchers.Default + root) { + assertEquals(listOf("X", "Y"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsMerged() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + expect(1) + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + launch(Dispatchers.Default + MyMutableElement(mutableListOf("Z"))) { + assertEquals(listOf("X", "Y", "Z"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsNotOverwrittenWithContext() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + expect(1) + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + withContext(Dispatchers.Default + root) { + assertEquals(listOf("X", "Y"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsCopiedForRunBlocking() = runTest { + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + runBlocking(root) { + assertNotSame(originalData, threadLocalData.get()) + } + } + + @Test + fun testDataIsCopiedForCoroutine() = runTest { + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + expect(1) + launch(root) { + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } +} diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index e1e29581a7..6e2dac1a29 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -49,6 +49,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): combined + (DefaultDelay as CoroutineContext.Element) else combined } +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + return this + addedContext +} + // No debugging facilities on native internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()