diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index f908f964d4..0368cf1cc0 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -328,6 +328,11 @@ public final class kotlinx/coroutines/GlobalScope : kotlinx/coroutines/Coroutine public abstract interface annotation class kotlinx/coroutines/InternalCoroutinesApi : java/lang/annotation/Annotation { } +public final class kotlinx/coroutines/InterruptibleKt { + public static final fun runInterruptible (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun runInterruptible$default (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; +} + public abstract interface class kotlinx/coroutines/Job : kotlin/coroutines/CoroutineContext$Element { public static final field Key Lkotlinx/coroutines/Job$Key; public abstract fun attachChild (Lkotlinx/coroutines/ChildJob;)Lkotlinx/coroutines/ChildHandle; diff --git a/kotlinx-coroutines-core/jvm/src/Interruptible.kt b/kotlinx-coroutines-core/jvm/src/Interruptible.kt new file mode 100644 index 0000000000..a4144d87d6 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/src/Interruptible.kt @@ -0,0 +1,164 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.atomicfu.* +import kotlin.coroutines.* + +/** + * Calls the specified [block] with a given coroutine context in a interruptible manner. + * The blocking code block will be interrupted and this function will throw [CancellationException] + * if the coroutine is cancelled. + * The specified [coroutineContext] directly translates into [withContext] argument. + * + * Example: + * ``` + * val blockingJob = launch { + * // This function will throw CancellationException + * runInterruptible(Dispatchers.IO) { + * // This blocking procedure will be interrupted when this coroutine is canceled + * doSomethingElseUsefulInterruptible() + * } + * } + * + * delay(500L) + * blockingJob.cancel() // Interrupt blocking call + * } + * ``` + * + * There is also an optional context parameter to this function to enable single-call conversion of + * interruptible Java methods into suspending functions like this: + * ``` + * // With one call here we are moving the call to Dispatchers.IO and supporting interruption. + * suspend fun BlockingQueue.awaitTake(): T = + * runInterruptible(Dispatchers.IO) { queue.take() } + * ``` + */ +public suspend fun runInterruptible( + context: CoroutineContext = EmptyCoroutineContext, + block: () -> T +): T = withContext(context) { + runInterruptibleInExpectedContext(block) +} + +private suspend fun runInterruptibleInExpectedContext(block: () -> T): T { + try { + // No job -> no cancellation + val job = coroutineContext[Job] ?: return block() + val threadState = ThreadState(job) + try { + return block() + } finally { + threadState.clearInterrupt() + } + } catch (e: InterruptedException) { + throw CancellationException("Blocking call was interrupted due to parent cancellation").initCause(e) + } +} + +private const val WORKING = 0 +private const val FINISHED = 1 +private const val INTERRUPTING = 2 +private const val INTERRUPTED = 3 + +private class ThreadState : CompletionHandler { + /* + === States === + + WORKING: running normally + FINISH: complete normally + INTERRUPTING: canceled, going to interrupt this thread + INTERRUPTED: this thread is interrupted + + === Possible Transitions === + + +----------------+ register job +-------------------------+ + | WORKING | cancellation listener | WORKING | + | (thread, null) | -------------------------> | (thread, cancel handle) | + +----------------+ +-------------------------+ + | | | + | cancel cancel | | complete + | | | + V | | + +---------------+ | | + | INTERRUPTING | <--------------------------------------+ | + +---------------+ | + | | + | interrupt | + | | + V V + +---------------+ +-------------------------+ + | INTERRUPTED | | FINISHED | + +---------------+ +-------------------------+ + */ + private val state: AtomicRef = atomic(State(WORKING, null)) + private val targetThread = Thread.currentThread() + + private data class State(@JvmField val state: Int, @JvmField val cancelHandle: DisposableHandle?) + + // We're using a non-primary constructor instead of init block of a primary constructor here, because + // we need to `return`. + constructor(job: Job) { + // Register cancellation handler + val cancelHandle = + job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = this) + // Either we successfully stored it or it was immediately cancelled + state.loop { s -> + when (s.state) { + // Happy-path, move forward + WORKING -> if (state.compareAndSet(s, State(WORKING, cancelHandle))) return + // Immediately cancelled, just continue + INTERRUPTING, INTERRUPTED -> return + else -> throw IllegalStateException("Illegal state $s") + } + } + } + + fun clearInterrupt() { + /* + * Do not allow to untriggered interrupt to leak + */ + state.loop { s -> + when (s.state) { + WORKING -> if (state.compareAndSet(s, State(FINISHED, null))) { + s.cancelHandle?.dispose() + return + } + INTERRUPTING -> { + /* + * Spin, cancellation mechanism is interrupting our thread right now + * and we have to wait it and then clear interrupt status + */ + } + INTERRUPTED -> { + // Clear it and bail out + Thread.interrupted(); + return + } + else -> error("Illegal state: $s") + } + } + } + + // Cancellation handler + override fun invoke(cause: Throwable?) { + state.loop { s -> + when (s.state) { + // Working -> try to transite state and interrupt the thread + WORKING -> { + if (state.compareAndSet(s, State(INTERRUPTING, null))) { + targetThread.interrupt() + state.value = State(INTERRUPTED, null) + return + } + } + // Finished -- runInterruptible is already complete + FINISHED -> return + INTERRUPTING, INTERRUPTED -> return + else -> error("Illegal state: $s") + } + } + } +} diff --git a/kotlinx-coroutines-core/jvm/test/RunInterruptibleStressTest.kt b/kotlinx-coroutines-core/jvm/test/RunInterruptibleStressTest.kt new file mode 100644 index 0000000000..03c7c6ecb8 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/RunInterruptibleStressTest.kt @@ -0,0 +1,58 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import java.util.concurrent.atomic.* +import kotlin.test.* + +class RunInterruptibleStressTest : TestBase() { + + private val dispatcher = Dispatchers.IO + private val REPEAT_TIMES = 1000 * stressTestMultiplier + + @Test + fun testStress() = runBlocking { + val interruptLeak = AtomicBoolean(false) + val enterCount = AtomicInteger(0) + val interruptedCount = AtomicInteger(0) + val otherExceptionCount = AtomicInteger(0) + + repeat(REPEAT_TIMES) { repeat -> + val job = launch(dispatcher, start = CoroutineStart.LAZY) { + try { + runInterruptible { + enterCount.incrementAndGet() + try { + Thread.sleep(Long.MAX_VALUE) + } catch (e: InterruptedException) { + interruptedCount.incrementAndGet() + throw e + } + } + } catch (e: CancellationException) { + } catch (e: Throwable) { + otherExceptionCount.incrementAndGet() + } finally { + interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted) + } + } + + val cancelJob = launch(dispatcher, start = CoroutineStart.LAZY) { + job.cancel() + } + + job.start() + val canceller = launch(dispatcher) { + cancelJob.start() + } + + joinAll(job, cancelJob, canceller) + } + + assertFalse(interruptLeak.get()) + assertEquals(enterCount.get(), interruptedCount.get()) + assertEquals(0, otherExceptionCount.get()) + } +} diff --git a/kotlinx-coroutines-core/jvm/test/RunInterruptibleTest.kt b/kotlinx-coroutines-core/jvm/test/RunInterruptibleTest.kt new file mode 100644 index 0000000000..e755b17d91 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/RunInterruptibleTest.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.coroutines.channels.* +import java.io.* +import java.util.concurrent.* +import java.util.concurrent.atomic.* +import kotlin.test.* + +class RunInterruptibleTest : TestBase() { + + @Test + fun testNormalRun() = runTest { + val result = runInterruptible { + val x = 1 + val y = 2 + Thread.sleep(1) + x + y + } + assertEquals(3, result) + } + + @Test + fun testExceptionalRun() = runTest { + try { + runInterruptible { + expect(1) + throw TestException() + } + } catch (e: TestException) { + finish(2) + } + } + + @Test + fun testInterrupt() = runTest { + val latch = Channel(1) + val job = launch { + runInterruptible(Dispatchers.IO) { + expect(2) + latch.offer(Unit) + try { + Thread.sleep(10_000L) + expectUnreached() + } catch (e: InterruptedException) { + expect(4) + assertFalse { Thread.currentThread().isInterrupted } + } + } + } + + launch(start = CoroutineStart.UNDISPATCHED) { + expect(1) + latch.receive() + expect(3) + job.cancelAndJoin() + }.join() + finish(5) + } +}