Skip to content

Commit

Permalink
Properly preserve thread local values for coroutines that are not int…
Browse files Browse the repository at this point in the history
…ercepted with DispatchedContinuation (#3252)

* Properly preserve thread local values for coroutines that are not intercepted with DispatchedContinuation

Fixes #2930
  • Loading branch information
qwwdfsad committed Apr 18, 2022
1 parent 163a55e commit c1cd02c
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
31 changes: 31 additions & 0 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Expand Up @@ -181,6 +181,37 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
*/
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

init {
/*
* This is a hack for a very specific case in #2930 unless #3253 is implemented.
* 'ThreadLocalStressTest' covers this change properly.
*
* The scenario this change covers is the following:
* 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function,
* e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking
* `withContext(tlElement)` which creates `UndispatchedCoroutine`.
* 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
* and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both
* do thread context element tracking.
* 3) So thread locals never got chance to get properly set up via `saveThreadContext`,
* but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
*
* Here we detect precisely this situation and properly setup context to recover later.
*
*/
if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
/*
* We cannot just "read" the elements as there is no such API,
* so we update-restore it immediately and use the intermediate value
* as the initial state, leveraging the fact that thread context element
* is idempotent and such situations are increasingly rare.
*/
val values = updateThreadContext(context, null)
restoreThreadContext(context, values)
saveThreadContext(context, values)
}
}

fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
threadStateToRecover.set(context to oldValue)
}
Expand Down
95 changes: 94 additions & 1 deletion kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt
Expand Up @@ -4,6 +4,10 @@

package kotlinx.coroutines

import kotlinx.coroutines.sync.*
import java.util.concurrent.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlin.test.*


Expand Down Expand Up @@ -63,10 +67,99 @@ class ThreadLocalStressTest : TestBase() {
withContext(threadLocal.asContextElement("foo")) {
yield()
cancel()
suspendCancellableCoroutineReusable<Unit> { }
suspendCancellableCoroutineReusable<Unit> { }
}
} finally {
assertEquals(expectedValue, threadLocal.get())
}
}

/*
* Another set of tests for undispatcheable continuations that do not require stress test multiplier.
* Also note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate error from
* `resumeWith`
*/

@Test
fun testNonDispatcheableLeak() {
repeat(100) {
doTestWithPreparation(
::doTest,
{ threadLocal.set(null) }) { threadLocal.get() == null }
assertNull(threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithInitial() {
repeat(100) {
doTestWithPreparation(::doTest, { threadLocal.set("initial") }) { threadLocal.get() == "initial" }
assertEquals("initial", threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithContextSwitch() {
repeat(100) {
doTestWithPreparation(
::doTestWithContextSwitch,
{ threadLocal.set(null) }) { threadLocal.get() == null }
assertNull(threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithInitialWithContextSwitch() {
repeat(100) {
doTestWithPreparation(
::doTestWithContextSwitch,
{ threadLocal.set("initial") }) { true /* can randomly wake up on the non-main thread */ }
// Here we are always on the main thread
assertEquals("initial", threadLocal.get())
}
}

private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isValid: () -> Boolean) {
setup()
val latch = CountDownLatch(1)
testBody.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) {
if (!isValid()) {
Thread.currentThread().uncaughtExceptionHandler.uncaughtException(
Thread.currentThread(),
IllegalStateException("Unexpected error: thread local was not cleaned")
)
}
latch.countDown()
})
latch.await()
}

private suspend fun doTest() {
withContext(threadLocal.asContextElement("foo")) {
try {
coroutineScope {
val semaphore = Semaphore(1, 1)
cancel()
semaphore.acquire()
}
} catch (e: CancellationException) {
// Ignore cancellation
}
}
}

private suspend fun doTestWithContextSwitch() {
withContext(threadLocal.asContextElement("foo")) {
try {
coroutineScope {
val semaphore = Semaphore(1, 1)
GlobalScope.launch { }.join()
cancel()
semaphore.acquire()
}
} catch (e: CancellationException) {
// Ignore cancellation
}
}
}
}

0 comments on commit c1cd02c

Please sign in to comment.