Skip to content

Commit

Permalink
Implemented CopyableThreadContextElement with a `copyForChildCorout…
Browse files Browse the repository at this point in the history
…ine()`. (#3025)


* This is a `ThreadContextElement` that is copy-constructed when a new coroutine
is created and inherits the context.


Co-authored-by: Tyson Henning <yorick@google.com>

Fixes #2839
  • Loading branch information
qwwdfsad committed Nov 17, 2021
1 parent ae0c842 commit 603bd79
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 2 deletions.
11 changes: 11 additions & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Expand Up @@ -140,6 +140,17 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run
public fun <init> (Ljava/lang/String;Ljava/lang/Throwable;)V
}

public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement {
public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement;
}

public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls {
public static fun fold (Lkotlinx/coroutines/CopyableThreadContextElement;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
public static fun get (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element;
public static fun minusKey (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext;
public static fun plus (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
}

public abstract interface class kotlinx/coroutines/CopyableThrowable {
public abstract fun createCopy ()Ljava/lang/Throwable;
}
Expand Down
21 changes: 20 additions & 1 deletion kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Expand Up @@ -16,12 +16,31 @@ import kotlin.coroutines.jvm.internal.CoroutineStackFrame
*/
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = coroutineContext + context
val combined = coroutineContext.foldCopiesForChildCoroutine() + context
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
debug + Dispatchers.Default else debug
}

/**
* Returns the [CoroutineContext] for a child coroutine to inherit.
*
* If any [CopyableThreadContextElement] is in the [this], calls
* [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext]
* by folding the returned copied elements into [this].
*
* Returns [this] if `this` has zero [CopyableThreadContextElement] in it.
*/
private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext {
val hasToCopy = fold(false) { result, it ->
result || it is CopyableThreadContextElement<*>
}
if (!hasToCopy) return this
return fold<CoroutineContext>(EmptyCoroutineContext) { combined, it ->
combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it
}
}

/**
* Executes a block using a given coroutine context.
*/
Expand Down
63 changes: 63 additions & 0 deletions kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
Expand Up @@ -77,6 +77,69 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}

/**
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
*
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
*
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
* will be visible to _itself_ and any child coroutine launched _after_ that write.
*
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
* launching a child coroutine will not be visible to that child coroutine.
*
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
* correctly, regardless of the coroutine's structured concurrency.
*
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
* is in a coroutine:
*
* ```
* class TraceContextElement(val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
* companion object Key : CoroutineContext.Key<ThreadTraceContextElement>
* override val key: CoroutineContext.Key<ThreadTraceContextElement>
* get() = Key
*
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
* val oldState = traceThreadLocal.get()
* traceThreadLocal.set(data)
* return oldState
* }
*
* override fun restoreThreadContext(context: CoroutineContext, oldData: TraceData?) {
* traceThreadLocal.set(oldState)
* }
*
* override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
* // child coroutine visible to the child.
* return CopyForChildCoroutineElement(traceThreadLocal.get())
* }
* }
* ```
*
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
* `Thread`.
*/
@ExperimentalCoroutinesApi
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction.
*
* This function is called on the element each time a new coroutine inherits a context containing it,
* and the returned value is folded into the context given to the child.
*
* Since this method is called whenever a new coroutine is launched in a context containing this
* [CopyableThreadContextElement], implementations are performance-sensitive.
*/
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
}

/**
* Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
* maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
Expand Down
130 changes: 129 additions & 1 deletion kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt
Expand Up @@ -54,7 +54,6 @@ class ThreadContextElementTest : TestBase() {
assertNull(myThreadLocal.get())
}


@Test
fun testWithContext() = runTest {
expect(1)
Expand Down Expand Up @@ -86,6 +85,78 @@ class ThreadContextElementTest : TestBase() {

finish(7)
}

@Test
fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest {
var parentElement: MyElement? = null
var inheritedElement: MyElement? = null

newSingleThreadContext("withContext").use {
withContext(it + MyElement(MyData())) {
parentElement = coroutineContext[MyElement.Key]
launch {
inheritedElement = coroutineContext[MyElement.Key]
}
}
}

assertSame(inheritedElement, parentElement,
"Inner and outer coroutines did not have the same object reference to a" +
" ThreadContextElement that did not override `copyForChildCoroutine()`")
}

@Test
fun testCopyableElementCopiedOnLaunch() = runTest {
var parentElement: CopyForChildCoroutineElement? = null
var inheritedElement: CopyForChildCoroutineElement? = null

newSingleThreadContext("withContext").use {
withContext(it + CopyForChildCoroutineElement(MyData())) {
parentElement = coroutineContext[CopyForChildCoroutineElement.Key]
launch {
inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key]
}
}
}

assertNotSame(inheritedElement, parentElement,
"Inner coroutine did not copy its copyable ThreadContextElement.")
}

@Test
fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
val startData = MyData()
withContext(it + CopyForChildCoroutineElement(startData)) {
val forBlockData = MyData()
myThreadLocal.setForBlock(forBlockData) {
assertSame(myThreadLocal.get(), forBlockData)
launch {
assertSame(myThreadLocal.get(), forBlockData)
}
launch {
assertSame(myThreadLocal.get(), forBlockData)
// Modify value in child coroutine. Writes to the ThreadLocal and
// the (copied) ThreadLocalElement's memory are not visible to peer or
// ancestor coroutines, so this write is both threadsafe and coroutinesafe.
val innerCoroutineData = MyData()
myThreadLocal.setForBlock(innerCoroutineData) {
assertSame(myThreadLocal.get(), innerCoroutineData)
}
assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored.
}
launch {
val innerCoroutineData = MyData()
myThreadLocal.setForBlock(innerCoroutineData) {
assertSame(myThreadLocal.get(), innerCoroutineData)
}
assertSame(myThreadLocal.get(), forBlockData)
}
}
assertSame(myThreadLocal.get(), startData) // Asserts value was restored.
}
}
}
}

class MyData
Expand Down Expand Up @@ -114,3 +185,60 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
myThreadLocal.set(oldState)
}
}

/**
* A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine].
*/
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>

override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
get() = Key

override fun updateThreadContext(context: CoroutineContext): MyData? {
val oldState = myThreadLocal.get()
myThreadLocal.set(data)
return oldState
}

override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
myThreadLocal.set(oldState)
}

/**
* At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new
* child coroutine, and that value is copied to a new, unique, ThreadContextElement memory
* reference for the child coroutine to use uniquely.
*
* n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not
* the value initially passed to the ThreadContextElement in order to reflect writes made to the
* ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes
* will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the
* thread and calls [restoreThreadContext].
*/
override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
return CopyForChildCoroutineElement(myThreadLocal.get())
}
}

/**
* Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block].
*
* When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a
* [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically
* at every statement reached, whether that statement is reached immediately, across suspend and
* redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal`
* by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal`
* by the parent coroutine _after_ launching a child coroutine will not be visible to that child
* coroutine.
*/
private inline fun <ThreadLocalT, OutputT> ThreadLocal<ThreadLocalT>.setForBlock(
value: ThreadLocalT,
crossinline block: () -> OutputT
) {
val priorValue = get()
set(value)
block()
set(priorValue)
}

0 comments on commit 603bd79

Please sign in to comment.