Skip to content

Commit

Permalink
Properly preserve thread local values for coroutines that are not int…
Browse files Browse the repository at this point in the history
…ercepted with DispatchedContinuation

Fixes #2930
  • Loading branch information
qwwdfsad committed Apr 14, 2022
1 parent 90fa892 commit a50492a
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
31 changes: 31 additions & 0 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Expand Up @@ -181,6 +181,37 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
*/
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

init {
/*
* This is a hack for a very specific case in #2930 unless #3253 is implemented.
* 'ThreadLocalStressTest' covers this change properly.
*
* The scenario this change covers is the following:
* 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function,
* e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking
* `withContext(tlElement)` which creates `UndispatchedCoroutine`.
* 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
* and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both
* do thread context element tracking.
* 3) So thread locals never got chance to get properly set up via `saveThreadContext`,
* but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
*
* Here we detect precisely this situation and properly setup context to recover later.
*
*/
if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
/*
* We cannot just "read" the elements as there is no such API,
* so we update-restore it immediately and use the intermediate value
* as the initial state, leveraging the fact that thread context element
* is idempotent and such situations are increasingly rare.
*/
val values = updateThreadContext(context, null)
restoreThreadContext(context, values)
threadStateToRecover.set(context to values)
}
}

fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
threadStateToRecover.set(context to oldValue)
}
Expand Down
95 changes: 94 additions & 1 deletion kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt
Expand Up @@ -4,6 +4,10 @@

package kotlinx.coroutines

import kotlinx.coroutines.sync.*
import java.util.concurrent.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlin.test.*


Expand Down Expand Up @@ -63,10 +67,99 @@ class ThreadLocalStressTest : TestBase() {
withContext(threadLocal.asContextElement("foo")) {
yield()
cancel()
suspendCancellableCoroutineReusable<Unit> { }
suspendCancellableCoroutineReusable<Unit> { }
}
} finally {
assertEquals(expectedValue, threadLocal.get())
}
}

/*
* Another set of tests for undispatcheable continuations that do not require stress test multiplier.
* Also note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate error from
* `resumeWith`
*/

@Test
fun testNonDispatcheableLeak() {
repeat(100) {
doTestWithPreparation(
::doTest,
{ threadLocal.set(null) }) { threadLocal.get() != null }
assertNull(threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithInitial() {
repeat(100) {
doTestWithPreparation(::doTest, { threadLocal.set("initial") }) { threadLocal.get() != "initial" }
assertEquals("initial", threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithContextSwitch() {
repeat(100) {
doTestWithPreparation(
::doTestWithContextSwitch,
{ threadLocal.set(null) }) { threadLocal.get() != null }
assertNull(threadLocal.get())
}
}

@Test
fun testNonDispatcheableLeakWithInitialWithContextSwitch() {
repeat(100) {
doTestWithPreparation(
::doTestWithContextSwitch,
{ threadLocal.set("initial") }) { false /* can randomly wake up on the non-main thread */ }
// Here we are always on the main thread
assertEquals("initial", threadLocal.get())
}
}

private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isInvalid: () -> Boolean) {
setup()
val latch = CountDownLatch(1)
testBody.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) {
if (isInvalid()) {
Thread.currentThread().uncaughtExceptionHandler.uncaughtException(
Thread.currentThread(),
IllegalStateException("Unexpected error: thread local was not cleaned")
)
}
latch.countDown()
})
latch.await()
}

private suspend fun doTest() {
withContext(threadLocal.asContextElement("foo")) {
try {
coroutineScope {
val semaphore = Semaphore(1, 1)
cancel()
semaphore.acquire()
}
} catch (e: CancellationException) {
// Ignore cancellation
}
}
}

private suspend fun doTestWithContextSwitch() {
withContext(threadLocal.asContextElement("foo")) {
try {
coroutineScope {
val semaphore = Semaphore(1, 1)
GlobalScope.launch { }.join()
cancel()
semaphore.acquire()
}
} catch (e: CancellationException) {
// Ignore cancellation
}
}
}
}

0 comments on commit a50492a

Please sign in to comment.