Skip to content

Commit

Permalink
CopyableThreadContextElement implementation (#3227)
Browse files Browse the repository at this point in the history
New approach eagerly copies corresponding elements to avoid accidental top-level reuse and also provides merge capability in case when an element is being overwritten. Merge capability is crucial in tracing scenarios to properly preserve the state of linked thread locals 


Co-authored-by: dkhalanskyjb <52952525+dkhalanskyjb@users.noreply.github.com>
  • Loading branch information
qwwdfsad and dkhalanskyjb committed Apr 4, 2022
1 parent 8133c97 commit a5dd74b
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 35 deletions.
4 changes: 3 additions & 1 deletion kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Expand Up @@ -141,7 +141,8 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run
}

public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement {
public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement;
public abstract fun copyForChild ()Lkotlinx/coroutines/CopyableThreadContextElement;
public abstract fun mergeForChild (Lkotlin/coroutines/CoroutineContext$Element;)Lkotlin/coroutines/CoroutineContext;
}

public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls {
Expand All @@ -156,6 +157,7 @@ public abstract interface class kotlinx/coroutines/CopyableThrowable {
}

public final class kotlinx/coroutines/CoroutineContextKt {
public static final fun newCoroutineContext (Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
public static final fun newCoroutineContext (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
}

Expand Down
3 changes: 2 additions & 1 deletion kotlinx-coroutines-core/common/src/Builders.common.kt
Expand Up @@ -148,7 +148,8 @@ public suspend fun <T> withContext(
return suspendCoroutineUninterceptedOrReturn sc@ { uCont ->
// compute new context
val oldContext = uCont.context
val newContext = oldContext + context
// Copy CopyableThreadContextElement if necessary
val newContext = oldContext.newCoroutineContext(context)
// always check for cancellation of new context
newContext.ensureActive()
// FAST PATH #1 -- new context is the same as the old one
Expand Down
12 changes: 10 additions & 2 deletions kotlinx-coroutines-core/common/src/CoroutineContext.common.kt
Expand Up @@ -7,11 +7,19 @@ package kotlinx.coroutines
import kotlin.coroutines.*

/**
* Creates a context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on).
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
* and copyable-thread-local facilities on JVM.
*/
public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext

/**
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
* @suppress
*/
@InternalCoroutinesApi
public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext

@PublishedApi
@Suppress("PropertyName")
internal expect val DefaultDelay: Delay
Expand Down
4 changes: 4 additions & 0 deletions kotlinx-coroutines-core/js/src/CoroutineContext.kt
Expand Up @@ -42,6 +42,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext):
combined + Dispatchers.Default else combined
}

public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
return this + addedContext
}

// No debugging facilities on JS
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()
Expand Down
87 changes: 67 additions & 20 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Expand Up @@ -9,36 +9,83 @@ import kotlin.coroutines.*
import kotlin.coroutines.jvm.internal.CoroutineStackFrame

/**
* Creates context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher nor
* [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on).
*
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
* and copyable-thread-local facilities on JVM.
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
*/
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = coroutineContext.foldCopiesForChildCoroutine() + context
val combined = foldCopies(coroutineContext, context, true)
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
debug + Dispatchers.Default else debug
}

/**
* Returns the [CoroutineContext] for a child coroutine to inherit.
*
* If any [CopyableThreadContextElement] is in the [this], calls
* [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext]
* by folding the returned copied elements into [this].
*
* Returns [this] if `this` has zero [CopyableThreadContextElement] in it.
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
* @suppress
*/
private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext {
val hasToCopy = fold(false) { result, it ->
result || it is CopyableThreadContextElement<*>
@InternalCoroutinesApi
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
/*
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
* contains copyable elements.
*/
if (!addedContext.hasCopyableElements()) return this + addedContext
return foldCopies(this, addedContext, false)
}

private fun CoroutineContext.hasCopyableElements(): Boolean =
fold(false) { result, it -> result || it is CopyableThreadContextElement<*> }

/**
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
* The rules are the following:
* * If neither context has CTCE, the sum of two contexts is returned
* * Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
* * Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
* * Every CTCE from the right-hand side context that hasn't been merged is copied
* * Everything else is added to the resulting context as is.
*/
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
// Do we have something to copy left-hand side?
val hasElementsLeft = originalContext.hasCopyableElements()
val hasElementsRight = appendContext.hasCopyableElements()

// Nothing to fold, so just return the sum of contexts
if (!hasElementsLeft && !hasElementsRight) {
return originalContext + appendContext
}

var leftoverContext = appendContext
val folded = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
if (element !is CopyableThreadContextElement<*>) return@fold result + element
// Will this element be overwritten?
val newElement = leftoverContext[element.key]
// No, just copy it
if (newElement == null) {
// For 'withContext'-like builders we do not copy as the element is not shared
return@fold result + if (isNewCoroutine) element.copyForChild() else element
}
// Yes, then first remove the element from append context
leftoverContext = leftoverContext.minusKey(element.key)
// Return the sum
@Suppress("UNCHECKED_CAST")
return@fold result + (element as CopyableThreadContextElement<Any?>).mergeForChild(newElement)
}
if (!hasToCopy) return this
return fold<CoroutineContext>(EmptyCoroutineContext) { combined, it ->
combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it

if (hasElementsRight) {
leftoverContext = leftoverContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
// We're appending new context element -- we have to copy it, otherwise it may be shared with others
if (element is CopyableThreadContextElement<*>) {
return@fold result + element.copyForChild()
}
return@fold result + element
}
}
return folded + leftoverContext
}

/**
Expand Down Expand Up @@ -77,7 +124,7 @@ internal actual inline fun <T> withContinuationContext(continuation: Continuatio
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
if (this !is CoroutineStackFrame) return null
/*
* Fast-path to detect whether we have unispatched coroutine at all in our stack.
* Fast-path to detect whether we have undispatched coroutine at all in our stack.
*
* Implementation note.
* If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
Expand All @@ -88,8 +135,8 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont
* Both options should work, but it requires more careful studying of the performance
* and, mostly, maintainability impact.
*/
val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker] !== null
if (!potentiallyHasUndispatchedCorotuine) return null
val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null
if (!potentiallyHasUndispatchedCoroutine) return null
val completion = undispatchedCompletion()
completion?.saveThreadContext(context, oldValue)
return completion
Expand Down
33 changes: 27 additions & 6 deletions kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
Expand Up @@ -80,7 +80,7 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
/**
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
*
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
* When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement]
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
*
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
Expand All @@ -99,6 +99,7 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
* ```
* class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
* companion object Key : CoroutineContext.Key<TraceContextElement>
*
* override val key: CoroutineContext.Key<TraceContextElement> = Key
*
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
Expand All @@ -111,32 +112,52 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
* traceThreadLocal.set(oldState)
* }
*
* override fun copyForChildCoroutine(): CopyableThreadContextElement<TraceData?> {
* override fun copyForChild(): TraceContextElement {
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
* // child coroutine visible to the child.
* return TraceContextElement(traceThreadLocal.get()?.copy())
* }
*
* override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
* // Merge operation defines how to handle situations when both
* // the parent coroutine has an element in the context and
* // an element with the same key was also
* // explicitly passed to the child coroutine.
* // If merging does not require special behavior,
* // the copy of the element can be returned.
* return TraceContextElement(traceThreadLocal.get()?.copy())
* }
* }
* ```
*
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
* `Thread`.
* A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's
* value is installed into the target thread local.
*/
@DelicateCoroutinesApi
@ExperimentalCoroutinesApi
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction.
* coroutine's context that is under construction if the added context does not contain an element with the same [key].
*
* This function is called on the element each time a new coroutine inherits a context containing it,
* and the returned value is folded into the context given to the child.
*
* Since this method is called whenever a new coroutine is launched in a context containing this
* [CopyableThreadContextElement], implementations are performance-sensitive.
*/
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
public fun copyForChild(): CopyableThreadContextElement<S>

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction if the added context does contain an element with the same [key].
*
* This method is invoked on the original element, accepting as the parameter
* the element that is supposed to overwrite it.
*/
public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext
}

/**
Expand Down
13 changes: 8 additions & 5 deletions kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt
Expand Up @@ -126,8 +126,7 @@ class ThreadContextElementTest : TestBase() {
@Test
fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
val startData = MyData()
withContext(it + CopyForChildCoroutineElement(startData)) {
withContext(it + CopyForChildCoroutineElement(MyData())) {
val forBlockData = MyData()
myThreadLocal.setForBlock(forBlockData) {
assertSame(myThreadLocal.get(), forBlockData)
Expand All @@ -153,7 +152,7 @@ class ThreadContextElementTest : TestBase() {
assertSame(myThreadLocal.get(), forBlockData)
}
}
assertSame(myThreadLocal.get(), startData) // Asserts value was restored.
assertNull(myThreadLocal.get()) // Asserts value was restored to its origin
}
}
}
Expand Down Expand Up @@ -187,7 +186,7 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
}

/**
* A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine].
* A [ThreadContextElement] that implements copy semantics in [copyForChild].
*/
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>
Expand All @@ -201,6 +200,10 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
return oldState
}

override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement {
TODO("Not used in tests")
}

override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
myThreadLocal.set(oldState)
}
Expand All @@ -216,7 +219,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
* will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the
* thread and calls [restoreThreadContext].
*/
override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
override fun copyForChild(): CopyForChildCoroutineElement {
return CopyForChildCoroutineElement(myThreadLocal.get())
}
}
Expand Down

0 comments on commit a5dd74b

Please sign in to comment.