diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 54e355ec37..fc278f9182 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -86,6 +86,10 @@ public final class kotlinx/coroutines/CancellableContinuationKt { public static final fun suspendCancellableCoroutine (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class kotlinx/coroutines/CancellationPointKt { + public static final fun interruptible (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public abstract interface class kotlinx/coroutines/ChildHandle : kotlinx/coroutines/DisposableHandle { public abstract fun childCancelled (Ljava/lang/Throwable;)Z } diff --git a/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt b/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt new file mode 100644 index 0000000000..7c5118bb07 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/src/CancellationPoint.kt @@ -0,0 +1,145 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn + +/** + * Makes a blocking code block cancellable (become a cancellation point of the coroutine). + * + * The blocking code block will be interrupted and this function will throw [CancellationException] + * if the coroutine is cancelled. + * + * Example: + * ``` + * GlobalScope.launch(Dispatchers.IO) { + * async { + * // This function will throw [CancellationException]. + * runInterruptible { + * doSomethingUseful() + * + * // This blocking procedure will be interrupted when this coroutine is canceled + * // by Exception thrown by the below async block. + * doSomethingElseUsefulInterruptible() + * } + * } + * + * async { + * delay(500L) + * throw Exception() + * } + * } + * ``` + * + * There is also an optional context parameter to this function to enable single-call conversion of + * interruptible Java methods into main-safe suspending functions like this: + * ``` + * // With one call here we are moving the call to Dispatchers.IO and supporting interruption. + * suspend fun BlockingQueue.awaitTake(): T = + * runInterruptible(Dispatchers.IO) { queue.take() } + * ``` + * + * @param context additional to [CoroutineScope.coroutineContext] context of the coroutine. + * @param block regular blocking block that will be interrupted on coroutine cancellation. + */ +public suspend fun runInterruptible( + context: CoroutineContext = EmptyCoroutineContext, + block: () -> T +): T { + // fast path: empty context + if (context === EmptyCoroutineContext) { return runInterruptibleInExpectedContext(block) } + // slow path: + return withContext(context) { runInterruptibleInExpectedContext(block) } +} + +private suspend fun runInterruptibleInExpectedContext(block: () -> T): T = + suspendCoroutineUninterceptedOrReturn sc@{ uCont -> + try { + // fast path: no job + val job = uCont.context[Job] ?: return@sc block() + // slow path + val threadState = ThreadState().apply { initInterrupt(job) } + try { + block() + } finally { + threadState.clearInterrupt() + } + } catch (e: InterruptedException) { + throw CancellationException() + } + } + +private class ThreadState { + + fun initInterrupt(job: Job) { + // starts with Init + if (state.value !== Init) throw IllegalStateException("impossible state") + // remembers this running thread + state.value = Working(Thread.currentThread(), null) + // watches the job for cancellation + val cancelHandle = + job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = CancelHandler()) + // remembers the cancel handle or drops it + state.loop { s -> + when { + s is Working -> if (state.compareAndSet(s, Working(s.thread, cancelHandle))) return + s === Interrupting || s === Interrupted -> return + s === Init || s === Finish -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + fun clearInterrupt() { + state.loop { s -> + when { + s is Working -> if (state.compareAndSet(s, Finish)) { s.cancelHandle!!.dispose(); return } + s === Interrupting -> Thread.yield() // eases the thread + s === Interrupted -> { Thread.interrupted(); return } // no interrupt leak + s === Init || s === Finish -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + + private inner class CancelHandler : CompletionHandler { + override fun invoke(cause: Throwable?) { + state.loop { s -> + when { + s is Working -> { + if (state.compareAndSet(s, Interrupting)) { + s.thread!!.interrupt() + state.value = Interrupted + return + } + } + s === Finish -> return + s === Interrupting || s === Interrupted -> return + s === Init -> throw IllegalStateException("impossible state") + else -> throw IllegalStateException("unknown state") + } + } + } + } + + private val state: AtomicRef = atomic(Init) + + private interface State + // initial state + private object Init : State + // cancellation watching is setup and/or the continuation is running + private data class Working(val thread: Thread?, val cancelHandle: DisposableHandle?) : State + // the continuation done running without interruption + private object Finish : State + // interrupting this thread + private object Interrupting: State + // done interrupting + private object Interrupted: State +} diff --git a/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt b/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt new file mode 100644 index 0000000000..33722992fe --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/InterruptibleCancellationPointTest.kt @@ -0,0 +1,172 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import java.io.IOException +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.* + +class InterruptibleCancellationPointTest: TestBase() { + + @Test + fun testNormalRun() = runBlocking { + var result = runInterruptible { + var x = doSomethingUsefulBlocking(1, 1) + var y = doSomethingUsefulBlocking(1, 2) + x + y + } + assertEquals(3, result) + } + + @Test + fun testExceptionThrow() = runBlocking { + val exception = Exception() + try { + runInterruptible { + throw exception + } + } catch (e: Throwable) { + assertEquals(exception, e) + return@runBlocking + } + fail() + } + + @Test + fun testRunWithContext() = runBlocking { + var runThread = + runInterruptible (Dispatchers.IO) { + Thread.currentThread() + } + assertNotEquals(runThread, Thread.currentThread()) + } + + @Test + fun testRunWithContextFastPath() = runBlocking { + var runThread : Thread = + runInterruptible { + Thread.currentThread() + } + assertEquals(runThread, Thread.currentThread()) + } + + @Test + fun testInterrupt() { + val count = AtomicInteger(0) + try { + expect(1) + runBlocking { + launch(Dispatchers.IO) { + async { + try { + // `runInterruptible` makes a blocking block cancelable (become a cancellation point) + // by interrupting it on cancellation and throws CancellationException + runInterruptible { + try { + doSomethingUsefulBlocking(100, 1) + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + expect(3) + throw e + } + } + } catch (e: CancellationException) { + expect(4) + } + } + + async { + delay(500L) + expect(2) + throw IOException() + } + } + } + } catch (e: IOException) { + expect(5) + } + finish(6) + } + + @Test + fun testNoInterruptLeak() = runBlocking { + var interrupted = true + + var task = launch(Dispatchers.IO) { + try { + runInterruptible { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } + } finally { + interrupted = Thread.currentThread().isInterrupted + } + } + + delay(500) + task.cancel() + task.join() + assertFalse(interrupted) + } + + @Test + fun testStress() { + val REPEAT_TIMES = 2_000 + + Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher -> + val interruptLeak = AtomicBoolean(false) + val enterCount = AtomicInteger(0) + val interruptedCount = AtomicInteger(0) + val otherExceptionCount = AtomicInteger(0) + + runBlocking { + repeat(REPEAT_TIMES) { repeat -> + var job = launch(start = CoroutineStart.LAZY, context = dispatcher) { + try { + runInterruptible { + enterCount.incrementAndGet() + try { + doSomethingUsefulBlocking(Long.MAX_VALUE, 0) + } catch (e: InterruptedException) { + interruptedCount.incrementAndGet() + throw e + } + } + } catch (e: CancellationException) { + } catch (e: Throwable) { + otherExceptionCount.incrementAndGet() + } finally { + interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted) + } + } + + var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) { + job.cancel() + } + + launch (dispatcher) { + delay((REPEAT_TIMES - repeat).toLong()) + job.start() + } + + launch (dispatcher) { + delay(repeat.toLong()) + cancelJob.start() + } + } + } + + assertFalse(interruptLeak.get()) + assertEquals(enterCount.get(), interruptedCount.get()) + assertEquals(0, otherExceptionCount.get()) + } + } + + private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int { + Thread.sleep(timeUseMillis) + return result + } +}