From 8befcd66b02a3af23f55d45240bdb19ab48a862f Mon Sep 17 00:00:00 2001 From: Trol Date: Wed, 22 Apr 2020 21:39:18 +0800 Subject: [PATCH] Support thread interrupting blocking functions (#1947) This is implementation of issue #1947 Signed-off-by: Trol --- .../api/kotlinx-coroutines-core.api | 5 + .../jvm/src/Interruptible.kt | 162 +++++++++++++++++ .../jvm/test/InterruptibleTest.kt | 163 ++++++++++++++++++ 3 files changed, 330 insertions(+) create mode 100644 kotlinx-coroutines-core/jvm/src/Interruptible.kt create mode 100644 kotlinx-coroutines-core/jvm/test/InterruptibleTest.kt diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 54e355ec37..f0383806d1 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -328,6 +328,11 @@ public final class kotlinx/coroutines/GlobalScope : kotlinx/coroutines/Coroutine public abstract interface annotation class kotlinx/coroutines/InternalCoroutinesApi : java/lang/annotation/Annotation { } +public final class kotlinx/coroutines/InterruptibleKt { + public static final fun runInterruptible (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun runInterruptible$default (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; +} + public abstract interface class kotlinx/coroutines/Job : kotlin/coroutines/CoroutineContext$Element { public static final field Key Lkotlinx/coroutines/Job$Key; public abstract fun attachChild (Lkotlinx/coroutines/ChildJob;)Lkotlinx/coroutines/ChildHandle; diff --git a/kotlinx-coroutines-core/jvm/src/Interruptible.kt b/kotlinx-coroutines-core/jvm/src/Interruptible.kt new file mode 100644 index 0000000000..bee340134b --- /dev/null +++ b/kotlinx-coroutines-core/jvm/src/Interruptible.kt @@ -0,0 +1,162 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn + +/** + * Makes a blocking code block cancellable (become a cancellation point of the coroutine). + * + * The blocking code block will be interrupted and this function will throw [CancellationException] + * if the coroutine is cancelled. + * + * Example: + * ``` + * GlobalScope.launch(Dispatchers.IO) { + * async { + * // This function will throw [CancellationException]. + * runInterruptible { + * doSomethingUseful() + * + * // This blocking procedure will be interrupted when this coroutine is canceled + * // by Exception thrown by the below async block. + * doSomethingElseUsefulInterruptible() + * } + * } + * + * async { + * delay(500L) + * throw Exception() + * } + * } + * ``` + * + * There is also an optional context parameter to this function to enable single-call conversion of + * interruptible Java methods into main-safe suspending functions like this: + * ``` + * // With one call here we are moving the call to Dispatchers.IO and supporting interruption. + * suspend fun BlockingQueue.awaitTake(): T = + * runInterruptible(Dispatchers.IO) { queue.take() } + * ``` + * + * @param context additional to [CoroutineScope.coroutineContext] context of the coroutine. + * @param block regular blocking block that will be interrupted on coroutine cancellation. + */ +public suspend fun runInterruptible( + context: CoroutineContext = EmptyCoroutineContext, + block: () -> T +): T = withContext(context) { runInterruptibleInExpectedContext(block) } + +private suspend fun runInterruptibleInExpectedContext(block: () -> T): T = + suspendCoroutineUninterceptedOrReturn sc@{ uCont -> + try { + // fast path: no job + val job = uCont.context[Job] ?: return@sc block() + // slow path + val threadState = ThreadState(job) + try { + block() + } finally { + threadState.clear() + } + } catch (e: InterruptedException) { + throw CancellationException("runInterruptible: interrupted").initCause(e) + } + } + +private const val WORKING = 0 +private const val FINISH = 1 +private const val INTERRUPTING = 2 +private const val INTERRUPTED = 3 + +private class ThreadState : CompletionHandler { + /* + === States === + + WORKING: running normally + FINISH: complete normally + INTERRUPTING: canceled, going to interrupt this thread + INTERRUPTED: this thread is interrupted + + + === Possible Transitions === + + +----------------+ remember +-------------------------+ + | WORKING | cancellation listener | WORKING | + | (thread, null) | -------------------------> | (thread, cancel handle) | + +----------------+ +-------------------------+ + | | | + | cancel cancel | | complete + | | | + V | | + +---------------+ | | + | INTERRUPTING | <--------------------------------------+ | + +---------------+ | + | | + | interrupt | + | | + V V + +---------------+ +-------------------------+ + | INTERRUPTED | | FINISH | + +---------------+ +-------------------------+ + */ + private val state: AtomicRef + + private data class State(val state: Int, val thread: Thread? = null, val cancelHandle: DisposableHandle? = null) + + // We're using a non-primary constructor instead of init block of a primary constructor here, because + // we need to `return`. + constructor (job: Job) { + state = atomic(State(WORKING, Thread.currentThread())) + // watches the job for cancellation + val cancelHandle = + job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = this) + // remembers the cancel handle or drops it + state.loop { s -> + when(s.state) { + WORKING -> if (state.compareAndSet(s, State(WORKING, s.thread, cancelHandle))) return + INTERRUPTING, INTERRUPTED -> return + FINISH -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + fun clear() { + state.loop { s -> + when(s.state) { + WORKING -> if (state.compareAndSet(s, State(FINISH))) { s.cancelHandle!!.dispose(); return } + INTERRUPTING -> { /* spin */ } + INTERRUPTED -> { Thread.interrupted(); return } // no interrupt leak + FINISH -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + override fun invoke(cause: Throwable?) = onCancel(cause) + + private inline fun onCancel(cause: Throwable?) { + state.loop { s -> + when(s.state) { + WORKING -> { + if (state.compareAndSet(s, State(INTERRUPTING))) { + s.thread!!.interrupt() + state.value = State(INTERRUPTED) + return + } + } + FINISH -> return + INTERRUPTING, INTERRUPTED -> return + else -> throw IllegalStateException("unknown state") + } + } + } +} diff --git a/kotlinx-coroutines-core/jvm/test/InterruptibleTest.kt b/kotlinx-coroutines-core/jvm/test/InterruptibleTest.kt new file mode 100644 index 0000000000..1f8dcb6dbb --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/InterruptibleTest.kt @@ -0,0 +1,163 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import java.io.IOException +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.* + +class InterruptibleTest: TestBase() { + @Test + fun testNormalRun() = runBlocking { + var result = runInterruptible { + var x = doSomethingUsefulBlocking(1, 1) + var y = doSomethingUsefulBlocking(1, 2) + x + y + } + assertEquals(3, result) + } + + @Test + fun testExceptionThrow() = runBlocking { + try { + runInterruptible { + throw TestException() + } + } catch (e: Throwable) { + assertTrue(e is TestException) + return@runBlocking + } + fail() + } + + @Test + fun testRunWithContext() = runBlocking { + var runThread = + runInterruptible (Dispatchers.IO) { + Thread.currentThread() + } + assertNotEquals(runThread, Thread.currentThread()) + } + + @Test + fun testInterrupt() { + val count = AtomicInteger(0) + try { + expect(1) + runBlocking { + launch(Dispatchers.IO) { + async { + try { + // `runInterruptible` makes a blocking block cancelable (become a cancellation point) + // by interrupting it on cancellation and throws CancellationException + runInterruptible { + try { + doSomethingUsefulBlocking(100, 1) + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + expect(3) + throw e + } + } + } catch (e: CancellationException) { + expect(4) + } + } + + async { + delay(500L) + expect(2) + throw IOException() + } + } + } + } catch (e: IOException) { + expect(5) + } + finish(6) + } + + @Test + fun testNoInterruptLeak() = runBlocking { + var interrupted = true + + var task = launch(Dispatchers.IO) { + try { + runInterruptible { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } + } finally { + interrupted = Thread.currentThread().isInterrupted + } + } + + delay(500) + task.cancel() + task.join() + assertFalse(interrupted) + } + + @Test + fun testStress() { + val REPEAT_TIMES = 2_000 + + Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher -> + val interruptLeak = AtomicBoolean(false) + val enterCount = AtomicInteger(0) + val interruptedCount = AtomicInteger(0) + val otherExceptionCount = AtomicInteger(0) + + runBlocking { + repeat(REPEAT_TIMES) { repeat -> + var job = launch(start = CoroutineStart.LAZY, context = dispatcher) { + try { + runInterruptible { + enterCount.incrementAndGet() + try { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + interruptedCount.incrementAndGet() + throw e + } + } + } catch (e: CancellationException) { + } catch (e: Throwable) { + otherExceptionCount.incrementAndGet() + } finally { + interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted) + } + } + + var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) { + job.cancel() + } + + launch (dispatcher) { + delay((REPEAT_TIMES - repeat).toLong()) + job.start() + } + + launch (dispatcher) { + delay(repeat.toLong()) + cancelJob.start() + } + } + } + + assertFalse(interruptLeak.get()) + assertEquals(enterCount.get(), interruptedCount.get()) + assertEquals(0, otherExceptionCount.get()) + } + } + + private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int { + Thread.sleep(timeUseMillis) + return result + } + + private class TestException : Exception() +}