diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index e08b805295..f088430f7b 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -181,6 +181,37 @@ internal actual class UndispatchedCoroutineactual constructor ( */ private var threadStateToRecover = ThreadLocal>() + init { + /* + * This is a hack for a very specific case in #2930 unless #3253 is implemented. + * 'ThreadLocalStressTest' covers this change properly. + * + * The scenario this change covers is the following: + * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function, + * e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking + * `withContext(tlElement)` which creates `UndispatchedCoroutine`. + * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()` + * and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both + * do thread context element tracking. + * 3) So thread locals never got chance to get properly set up via `saveThreadContext`, + * but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`. + * + * Here we detect precisely this situation and properly setup context to recover later. + * + */ + if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) { + /* + * We cannot just "read" the elements as there is no such API, + * so we update-restore it immediately and use the intermediate value + * as the initial state, leveraging the fact that thread context element + * is idempotent and such situations are increasingly rare. + */ + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + threadStateToRecover.set(context to values) + } + } + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { threadStateToRecover.set(context to oldValue) } diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt index f9941d0215..169b077674 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt @@ -4,6 +4,10 @@ package kotlinx.coroutines +import kotlinx.coroutines.sync.* +import java.util.concurrent.* +import kotlin.coroutines.* +import kotlin.coroutines.intrinsics.* import kotlin.test.* @@ -63,10 +67,99 @@ class ThreadLocalStressTest : TestBase() { withContext(threadLocal.asContextElement("foo")) { yield() cancel() - suspendCancellableCoroutineReusable { } + suspendCancellableCoroutineReusable { } } } finally { assertEquals(expectedValue, threadLocal.get()) } } + + /* + * Another set of tests for undispatcheable continuations that do not require stress test multiplier. + * Also note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate error from + * `resumeWith` + */ + + @Test + fun testNonDispatcheableLeak() { + repeat(100) { + doTestWithPreparation( + ::doTest, + { threadLocal.set(null) }) { threadLocal.get() != null } + assertNull(threadLocal.get()) + } + } + + @Test + fun testNonDispatcheableLeakWithInitial() { + repeat(100) { + doTestWithPreparation(::doTest, { threadLocal.set("initial") }) { threadLocal.get() != "initial" } + assertEquals("initial", threadLocal.get()) + } + } + + @Test + fun testNonDispatcheableLeakWithContextSwitch() { + repeat(100) { + doTestWithPreparation( + ::doTestWithContextSwitch, + { threadLocal.set(null) }) { threadLocal.get() != null } + assertNull(threadLocal.get()) + } + } + + @Test + fun testNonDispatcheableLeakWithInitialWithContextSwitch() { + repeat(100) { + doTestWithPreparation( + ::doTestWithContextSwitch, + { threadLocal.set("initial") }) { false /* can randomly wake up on the non-main thread */ } + // Here we are always on the main thread + assertEquals("initial", threadLocal.get()) + } + } + + private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isInvalid: () -> Boolean) { + setup() + val latch = CountDownLatch(1) + testBody.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) { + if (isInvalid()) { + Thread.currentThread().uncaughtExceptionHandler.uncaughtException( + Thread.currentThread(), + IllegalStateException("Unexpected error: thread local was not cleaned") + ) + } + latch.countDown() + }) + latch.await() + } + + private suspend fun doTest() { + withContext(threadLocal.asContextElement("foo")) { + try { + coroutineScope { + val semaphore = Semaphore(1, 1) + cancel() + semaphore.acquire() + } + } catch (e: CancellationException) { + // Ignore cancellation + } + } + } + + private suspend fun doTestWithContextSwitch() { + withContext(threadLocal.asContextElement("foo")) { + try { + coroutineScope { + val semaphore = Semaphore(1, 1) + GlobalScope.launch { }.join() + cancel() + semaphore.acquire() + } + } catch (e: CancellationException) { + // Ignore cancellation + } + } + } }