diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index d1fc624a5e..79f3cf4308 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -141,7 +141,8 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run } public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement { - public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement; + public abstract fun copyForChild ()Lkotlinx/coroutines/CopyableThreadContextElement; + public abstract fun mergeForChild (Lkotlin/coroutines/CoroutineContext$Element;)Lkotlin/coroutines/CoroutineContext; } public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls { @@ -156,6 +157,7 @@ public abstract interface class kotlinx/coroutines/CopyableThrowable { } public final class kotlinx/coroutines/CoroutineContextKt { + public static final fun newCoroutineContext (Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; public static final fun newCoroutineContext (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/kotlinx-coroutines-core/common/src/Builders.common.kt b/kotlinx-coroutines-core/common/src/Builders.common.kt index a11ffe9eb4..c360724245 100644 --- a/kotlinx-coroutines-core/common/src/Builders.common.kt +++ b/kotlinx-coroutines-core/common/src/Builders.common.kt @@ -148,7 +148,8 @@ public suspend fun withContext( return suspendCoroutineUninterceptedOrReturn sc@ { uCont -> // compute new context val oldContext = uCont.context - val newContext = oldContext + context + // Copy CopyableThreadContextElement if necessary + val newContext = oldContext.newCoroutineContext(context) // always check for cancellation of new context newContext.ensureActive() // FAST PATH #1 -- new context is the same as the old one diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt index da094e152d..9153f39821 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt @@ -7,11 +7,19 @@ package kotlinx.coroutines import kotlin.coroutines.* /** - * Creates a context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher or - * [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on). + * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or + * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. */ public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext +/** + * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. + * @suppress + */ +@InternalCoroutinesApi +public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext + @PublishedApi @Suppress("PropertyName") internal expect val DefaultDelay: Delay diff --git a/kotlinx-coroutines-core/js/src/CoroutineContext.kt b/kotlinx-coroutines-core/js/src/CoroutineContext.kt index 95cb3c2964..8036c88a10 100644 --- a/kotlinx-coroutines-core/js/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/js/src/CoroutineContext.kt @@ -42,6 +42,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): combined + Dispatchers.Default else combined } +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + return this + addedContext +} + // No debugging facilities on JS internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index d562207f8b..f1c4f0794c 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -9,36 +9,83 @@ import kotlin.coroutines.* import kotlin.coroutines.jvm.internal.CoroutineStackFrame /** - * Creates context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher nor - * [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on). - * + * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or + * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. */ @ExperimentalCoroutinesApi public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = coroutineContext.foldCopiesForChildCoroutine() + context + val combined = foldCopies(coroutineContext, context, true) val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) debug + Dispatchers.Default else debug } /** - * Returns the [CoroutineContext] for a child coroutine to inherit. - * - * If any [CopyableThreadContextElement] is in the [this], calls - * [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext] - * by folding the returned copied elements into [this]. - * - * Returns [this] if `this` has zero [CopyableThreadContextElement] in it. + * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. + * @suppress */ -private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext { - val hasToCopy = fold(false) { result, it -> - result || it is CopyableThreadContextElement<*> +@InternalCoroutinesApi +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + /* + * Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements) + * contains copyable elements. + */ + if (!addedContext.hasCopyableElements()) return this + addedContext + return foldCopies(this, addedContext, false) +} + +private fun CoroutineContext.hasCopyableElements(): Boolean = + fold(false) { result, it -> result || it is CopyableThreadContextElement<*> } + +/** + * Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary. + * The rules are the following: + * * If neither context has CTCE, the sum of two contexts is returned + * * Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context + * is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`. + * * Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild] + * * Every CTCE from the right-hand side context that hasn't been merged is copied + * * Everything else is added to the resulting context as is. + */ +private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext { + // Do we have something to copy left-hand side? + val hasElementsLeft = originalContext.hasCopyableElements() + val hasElementsRight = appendContext.hasCopyableElements() + + // Nothing to fold, so just return the sum of contexts + if (!hasElementsLeft && !hasElementsRight) { + return originalContext + appendContext + } + + var leftoverContext = appendContext + val folded = originalContext.fold(EmptyCoroutineContext) { result, element -> + if (element !is CopyableThreadContextElement<*>) return@fold result + element + // Will this element be overwritten? + val newElement = leftoverContext[element.key] + // No, just copy it + if (newElement == null) { + // For 'withContext'-like builders we do not copy as the element is not shared + return@fold result + if (isNewCoroutine) element.copyForChild() else element + } + // Yes, then first remove the element from append context + leftoverContext = leftoverContext.minusKey(element.key) + // Return the sum + @Suppress("UNCHECKED_CAST") + return@fold result + (element as CopyableThreadContextElement).mergeForChild(newElement) } - if (!hasToCopy) return this - return fold(EmptyCoroutineContext) { combined, it -> - combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it + + if (hasElementsRight) { + leftoverContext = leftoverContext.fold(EmptyCoroutineContext) { result, element -> + // We're appending new context element -- we have to copy it, otherwise it may be shared with others + if (element is CopyableThreadContextElement<*>) { + return@fold result + element.copyForChild() + } + return@fold result + element + } } + return folded + leftoverContext } /** @@ -77,7 +124,7 @@ internal actual inline fun withContinuationContext(continuation: Continuatio internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { if (this !is CoroutineStackFrame) return null /* - * Fast-path to detect whether we have unispatched coroutine at all in our stack. + * Fast-path to detect whether we have undispatched coroutine at all in our stack. * * Implementation note. * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: @@ -88,8 +135,8 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont * Both options should work, but it requires more careful studying of the performance * and, mostly, maintainability impact. */ - val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker] !== null - if (!potentiallyHasUndispatchedCorotuine) return null + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null val completion = undispatchedCompletion() completion?.saveThreadContext(context, oldValue) return completion diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index 1a960699c7..d2b6b6b988 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -80,7 +80,7 @@ public interface ThreadContextElement : CoroutineContext.Element { /** * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. * - * When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement] + * When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement] * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. * * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine @@ -99,6 +99,7 @@ public interface ThreadContextElement : CoroutineContext.Element { * ``` * class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement { * companion object Key : CoroutineContext.Key + * * override val key: CoroutineContext.Key = Key * * override fun updateThreadContext(context: CoroutineContext): TraceData? { @@ -111,24 +112,35 @@ public interface ThreadContextElement : CoroutineContext.Element { * traceThreadLocal.set(oldState) * } * - * override fun copyForChildCoroutine(): CopyableThreadContextElement { + * override fun copyForChild(): TraceContextElement { * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes * // ThreadLocal writes between resumption of the parent coroutine and the launch of the * // child coroutine visible to the child. * return TraceContextElement(traceThreadLocal.get()?.copy()) * } + * + * override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext { + * // Merge operation defines how to handle situations when both + * // the parent coroutine has an element in the context and + * // an element with the same key was also + * // explicitly passed to the child coroutine. + * // If merging does not require special behavior, + * // the copy of the element can be returned. + * return TraceContextElement(traceThreadLocal.get()?.copy()) + * } * } * ``` * - * A coroutine using this mechanism can safely call Java code that assumes it's called using a - * `Thread`. + * A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's + * value is installed into the target thread local. */ +@DelicateCoroutinesApi @ExperimentalCoroutinesApi public interface CopyableThreadContextElement : ThreadContextElement { /** * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child - * coroutine's context that is under construction. + * coroutine's context that is under construction if the added context does not contain an element with the same [key]. * * This function is called on the element each time a new coroutine inherits a context containing it, * and the returned value is folded into the context given to the child. @@ -136,7 +148,16 @@ public interface CopyableThreadContextElement : ThreadContextElement { * Since this method is called whenever a new coroutine is launched in a context containing this * [CopyableThreadContextElement], implementations are performance-sensitive. */ - public fun copyForChildCoroutine(): CopyableThreadContextElement + public fun copyForChild(): CopyableThreadContextElement + + /** + * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * coroutine's context that is under construction if the added context does contain an element with the same [key]. + * + * This method is invoked on the original element, accepting as the parameter + * the element that is supposed to overwrite it. + */ + public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext } /** diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index baba4aa8e6..ec45406bce 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -126,8 +126,7 @@ class ThreadContextElementTest : TestBase() { @Test fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest { newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { - val startData = MyData() - withContext(it + CopyForChildCoroutineElement(startData)) { + withContext(it + CopyForChildCoroutineElement(MyData())) { val forBlockData = MyData() myThreadLocal.setForBlock(forBlockData) { assertSame(myThreadLocal.get(), forBlockData) @@ -153,7 +152,7 @@ class ThreadContextElementTest : TestBase() { assertSame(myThreadLocal.get(), forBlockData) } } - assertSame(myThreadLocal.get(), startData) // Asserts value was restored. + assertNull(myThreadLocal.get()) // Asserts value was restored to its origin } } } @@ -187,7 +186,7 @@ class MyElement(val data: MyData) : ThreadContextElement { } /** - * A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine]. + * A [ThreadContextElement] that implements copy semantics in [copyForChild]. */ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { companion object Key : CoroutineContext.Key @@ -201,6 +200,10 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle return oldState } + override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement { + TODO("Not used in tests") + } + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { myThreadLocal.set(oldState) } @@ -216,7 +219,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle * will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the * thread and calls [restoreThreadContext]. */ - override fun copyForChildCoroutine(): CopyableThreadContextElement { + override fun copyForChild(): CopyForChildCoroutineElement { return CopyForChildCoroutineElement(myThreadLocal.get()) } } diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt new file mode 100644 index 0000000000..34e5955fd7 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt @@ -0,0 +1,134 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.coroutines.* +import kotlin.test.* + +class ThreadContextMutableCopiesTest : TestBase() { + companion object { + val threadLocalData: ThreadLocal> = ThreadLocal.withInitial { ArrayList() } + } + + class MyMutableElement( + val mutableData: MutableList + ) : CopyableThreadContextElement> { + + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key<*> + get() = Key + + override fun updateThreadContext(context: CoroutineContext): MutableList { + val st = threadLocalData.get() + threadLocalData.set(mutableData) + return st + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: MutableList) { + threadLocalData.set(oldState) + } + + override fun copyForChild(): MyMutableElement { + return MyMutableElement(ArrayList(mutableData)) + } + + override fun mergeForChild(overwritingElement: CoroutineContext.Element): MyMutableElement { + overwritingElement as MyMutableElement // <- app-specific, may be another subtype + return MyMutableElement((mutableData.toSet() + overwritingElement.mutableData).toMutableList()) + } + } + + @Test + fun testDataIsCopied() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + val data = threadLocalData.get() + expect(1) + launch(root) { + assertNotSame(data, threadLocalData.get()) + assertEquals(data, threadLocalData.get()) + finish(2) + } + } + } + + @Test + fun testDataIsNotOverwritten() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + expect(1) + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + launch(Dispatchers.Default + root) { + assertEquals(listOf("X", "Y"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsMerged() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + expect(1) + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + launch(Dispatchers.Default + MyMutableElement(mutableListOf("Z"))) { + assertEquals(listOf("X", "Y", "Z"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsNotOverwrittenWithContext() = runTest { + val root = MyMutableElement(ArrayList()) + runBlocking(root) { + val originalData = threadLocalData.get() + threadLocalData.get().add("X") + expect(1) + launch { + threadLocalData.get().add("Y") + // Note here, +root overwrites the data + withContext(Dispatchers.Default + root) { + assertEquals(listOf("X", "Y"), threadLocalData.get()) + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } + } + } + + @Test + fun testDataIsCopiedForRunBlocking() = runTest { + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + runBlocking(root) { + assertNotSame(originalData, threadLocalData.get()) + } + } + + @Test + fun testDataIsCopiedForCoroutine() = runTest { + val root = MyMutableElement(ArrayList()) + val originalData = root.mutableData + expect(1) + launch(root) { + assertNotSame(originalData, threadLocalData.get()) + finish(2) + } + } +} diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index e1e29581a7..6e2dac1a29 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -49,6 +49,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): combined + (DefaultDelay as CoroutineContext.Element) else combined } +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + return this + addedContext +} + // No debugging facilities on native internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()