diff --git a/kotlinx-coroutines-test/src/DelayController.kt b/kotlinx-coroutines-test/src/DelayController.kt index 54e9c8ae5e..3f608e9e95 100644 --- a/kotlinx-coroutines-test/src/DelayController.kt +++ b/kotlinx-coroutines-test/src/DelayController.kt @@ -2,6 +2,7 @@ package kotlinx.coroutines.test import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.channels.ConflatedBroadcastChannel /** * Control the virtual clock time of a [CoroutineDispatcher]. @@ -93,6 +94,9 @@ public interface DelayController { * * This is useful when testing functions that start a coroutine. By pausing the dispatcher assertions or * setup may be done between the time the coroutine is created and started. + * + * While in the paused block, the dispatcher will queue all dispatched coroutines and they will be resumed on + * whatever thread calls [advanceUntilIdle], [advanceTimeBy], or [runCurrent]. */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 public suspend fun pauseDispatcher(block: suspend () -> Unit) @@ -102,6 +106,9 @@ public interface DelayController { * * When paused, the dispatcher will not execute any coroutines automatically, and you must call [runCurrent] or * [advanceTimeBy], or [advanceUntilIdle] to execute coroutines. + * + * While paused, the dispatcher will queue all dispatched coroutines and they will be resumed on whatever thread + * calls [advanceUntilIdle], [advanceTimeBy], or [runCurrent]. */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 public fun pauseDispatcher() @@ -112,6 +119,28 @@ public interface DelayController { * Resumed dispatchers will automatically progress through all coroutines scheduled at the current time. To advance * time and execute coroutines scheduled in the future use, one of [advanceTimeBy], * or [advanceUntilIdle]. + * + * When the dispatcher is resumed, all execution be immediate in the thread that triggered it similar to + * [Dispatchers.Unconfined]. This means that the following code will not switch back from Dispatchers.IO after + * `withContext` + * + * ``` + * runBlockingTest { + * withContext(Dispatchers.IO) { doIo() } + * // runBlockingTest is still on Dispatchers.IO here + * } + * ``` + * + * For tests that need accurate threading behavior, [pauseDispatcher] will ensure that the following test dispatches + * on a controlled thread. + * + * ``` + * runBlockingTest { + * pauseDispatcher() + * withContext(Dispatchers.IO) { doIo() } + * // runBlockingTest has returned to it's starting thread here + * } + * ``` */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 public fun resumeDispatcher() diff --git a/kotlinx-coroutines-test/src/TestBuilders.kt b/kotlinx-coroutines-test/src/TestBuilders.kt index 7ef77bd643..93349cd686 100644 --- a/kotlinx-coroutines-test/src/TestBuilders.kt +++ b/kotlinx-coroutines-test/src/TestBuilders.kt @@ -5,8 +5,12 @@ package kotlinx.coroutines.test import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.selects.select import kotlin.coroutines.* +private const val DEFAULT_WAIT_FOR_OTHER_DISPATCHERS = 30_000L + /** * Executes a [testBody] inside an immediate execution dispatcher. * @@ -38,26 +42,79 @@ import kotlin.coroutines.* * (including coroutines suspended on join/await). * * @param context additional context elements. If [context] contains [CoroutineDispatcher] or [CoroutineExceptionHandler], - * then they must implement [DelayController] and [TestCoroutineExceptionHandler] respectively. + * then they must implement [DelayController] and [TestCoroutineExceptionHandler] respectively. + * @param waitForOtherDispatchers how long to wait for other dispatchers to execute tasks asynchronously, default 30 + * seconds * @param testBody The code of the unit-test. */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 -public fun runBlockingTest(context: CoroutineContext = EmptyCoroutineContext, testBody: suspend TestCoroutineScope.() -> Unit) { +public fun runBlockingTest( + context: CoroutineContext = EmptyCoroutineContext, + waitForOtherDispatchers: Long = DEFAULT_WAIT_FOR_OTHER_DISPATCHERS, + testBody: suspend TestCoroutineScope.() -> Unit +) { val (safeContext, dispatcher) = context.checkArguments() val startingJobs = safeContext.activeJobs() - val scope = TestCoroutineScope(safeContext) - val deferred = scope.async { - scope.testBody() + + var testScope: TestCoroutineScope? = null + + val deferred = CoroutineScope(safeContext).async { + val localTestScope = TestCoroutineScope(coroutineContext) + testScope = localTestScope + localTestScope.testBody() } - dispatcher.advanceUntilIdle() - deferred.getCompletionExceptionOrNull()?.let { - throw it + + val didTimeout = deferred.waitForCompletion(waitForOtherDispatchers, dispatcher, dispatcher as IdleWaiter) + + if (deferred.isCompleted) { + deferred.getCompletionExceptionOrNull()?.let { + throw it + } } - scope.cleanupTestCoroutines() + + testScope!!.cleanupTestCoroutines() val endingJobs = safeContext.activeJobs() - if ((endingJobs - startingJobs).isNotEmpty()) { - throw UncompletedCoroutinesError("Test finished with active jobs: $endingJobs") + + // TODO: should these be separate exceptions to allow for tests to detect difference? + if (didTimeout) { + val message = """ + runBlockingTest timed out after waiting ${waitForOtherDispatchers}ms for coroutines to complete. + Active jobs after test (may be empty): $endingJobs + """.trimIndent() + throw UncompletedCoroutinesError(message) + } else if ((endingJobs - startingJobs).isNotEmpty()) { + throw UncompletedCoroutinesError("Test finished with active jobs: $endingJobs ") + } +} + +private fun Deferred.waitForCompletion(wait: Long, delayController: DelayController, park: IdleWaiter): Boolean { + var didTimeout = false + + runBlocking { + val unparkChannel = Channel(1) + val job = launch { + while(true) { + park.suspendUntilNextDispatch() + unparkChannel.send(Unit) + } + } + + try { + withTimeout(wait) { + while(!isCompleted) { + delayController.advanceUntilIdle() + select { + onAwait { Unit } + unparkChannel.onReceive { Unit } + } + } + } + } catch (exception: TimeoutCancellationException) { + didTimeout = true + } + job.cancel() } + return didTimeout } private fun CoroutineContext.activeJobs(): Set { @@ -69,18 +126,19 @@ private fun CoroutineContext.activeJobs(): Set { */ // todo: need documentation on how this extension is supposed to be used @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 -public fun TestCoroutineScope.runBlockingTest(block: suspend TestCoroutineScope.() -> Unit) = runBlockingTest(coroutineContext, block) +public fun TestCoroutineScope.runBlockingTest(waitForOtherDispatchers: Long = DEFAULT_WAIT_FOR_OTHER_DISPATCHERS, block: suspend TestCoroutineScope.() -> Unit) = runBlockingTest(coroutineContext, waitForOtherDispatchers, block) /** * Convenience method for calling [runBlockingTest] on an existing [TestCoroutineDispatcher]. */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 -public fun TestCoroutineDispatcher.runBlockingTest(block: suspend TestCoroutineScope.() -> Unit) = runBlockingTest(this, block) +public fun TestCoroutineDispatcher.runBlockingTest(waitForOtherDispatchers: Long = DEFAULT_WAIT_FOR_OTHER_DISPATCHERS, block: suspend TestCoroutineScope.() -> Unit) = runBlockingTest(this, waitForOtherDispatchers, block) private fun CoroutineContext.checkArguments(): Pair { // TODO optimize it val dispatcher = get(ContinuationInterceptor).run { this?.let { require(this is DelayController) { "Dispatcher must implement DelayController: $this" } } + this?.let { require(this is IdleWaiter) { "Dispatcher must implement IdleWaiter" } } this ?: TestCoroutineDispatcher() } diff --git a/kotlinx-coroutines-test/src/TestCoroutineDispatcher.kt b/kotlinx-coroutines-test/src/TestCoroutineDispatcher.kt index 386fc8380d..ea7afef1d7 100644 --- a/kotlinx-coroutines-test/src/TestCoroutineDispatcher.kt +++ b/kotlinx-coroutines-test/src/TestCoroutineDispatcher.kt @@ -4,11 +4,14 @@ package kotlinx.coroutines.test -import kotlinx.atomicfu.* +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update import kotlinx.coroutines.* -import kotlinx.coroutines.internal.* -import kotlin.coroutines.* -import kotlin.math.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.internal.ThreadSafeHeap +import kotlinx.coroutines.internal.ThreadSafeHeapNode +import kotlin.coroutines.CoroutineContext +import kotlin.math.max /** * [CoroutineDispatcher] that performs both immediate and lazy execution of coroutines in tests @@ -22,10 +25,14 @@ import kotlin.math.* * not execute until a call to [DelayController.runCurrent] or the virtual clock-time has been advanced via one of the * methods on [DelayController]. * + * While in immediate mode [TestCoroutineDispatcher] behaves similar to [Dispatchers.Unconfined]. When resuming from + * another thread it will *not* switch threads. When in lazy mode, [TestCoroutineDispatcher] will enqueue all + * dispatches and whatever thread calls an [advanceUntilIdle], [advanceTimeBy], or [runCurrent] will continue execution. + * * @see DelayController */ @ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0 -public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController { +public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayController, IdleWaiter { private var dispatchImmediately = true set(value) { field = value @@ -44,10 +51,13 @@ public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayControl // Storing time in nanoseconds internally. private val _time = atomic(0L) + private val waitLock = Channel(capacity = 1) + /** @suppress */ override fun dispatch(context: CoroutineContext, block: Runnable) { if (dispatchImmediately) { block.run() + unpark() } else { post(block) } @@ -79,14 +89,18 @@ public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayControl return "TestCoroutineDispatcher[currentTime=${currentTime}ms, queued=${queue.size}]" } - private fun post(block: Runnable) = + private fun post(block: Runnable) { queue.addLast(TimedRunnable(block, _counter.getAndIncrement())) + unpark() + } - private fun postDelayed(block: Runnable, delayTime: Long) = - TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime)) - .also { - queue.addLast(it) - } + private fun postDelayed(block: Runnable, delayTime: Long): TimedRunnable { + return TimedRunnable(block, _counter.getAndIncrement(), safePlus(currentTime, delayTime)) + .also { + queue.addLast(it) + unpark() + } + } private fun safePlus(currentTime: Long, delayTime: Long): Long { check(delayTime >= 0) @@ -132,11 +146,14 @@ public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayControl val next = queue.peek() ?: break advanceUntilTime(next.time) } + return currentTime - oldTime } /** @suppress */ - override fun runCurrent() = doActionsUntil(currentTime) + override fun runCurrent() { + doActionsUntil(currentTime) + } /** @suppress */ override suspend fun pauseDispatcher(block: suspend () -> Unit) { @@ -161,6 +178,7 @@ public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayControl /** @suppress */ override fun cleanupTestCoroutines() { + unpark() // process any pending cancellations or completions, but don't advance time doActionsUntil(currentTime) @@ -181,6 +199,14 @@ public class TestCoroutineDispatcher: CoroutineDispatcher(), Delay, DelayControl ) } } + + override suspend fun suspendUntilNextDispatch() { + waitLock.receive() + } + + private fun unpark() { + waitLock.offer(Unit) + } } /** @@ -212,4 +238,21 @@ private class TimedRunnable( } override fun toString() = "TimedRunnable(time=$time, run=$runnable)" -} \ No newline at end of file +} + +/** + * Alternative implementations of [TestCoroutineDispatcher] must implement this interface in order to be supported by + * [runBlockingTest]. + * + * This interface allows external code to suspend itself until the next dispatch is received. This is similar to park in + * a normal event loop, but doesn't require that [TestCoroutineDispatcher] block a thread while parked. + */ +interface IdleWaiter { + /** + * Attempt to suspend until the next dispatch is received. + * + * This method may resume immediately if any dispatch was received since the last time it was called. This ensures + * that dispatches won't be dropped if they happen just before calling [suspendUntilNextDispatch]. + */ + public suspend fun suspendUntilNextDispatch() +} diff --git a/kotlinx-coroutines-test/test/TestRunBlockingOrderTest.kt b/kotlinx-coroutines-test/test/TestRunBlockingOrderTest.kt index 0013a654a6..08b1665d6b 100644 --- a/kotlinx-coroutines-test/test/TestRunBlockingOrderTest.kt +++ b/kotlinx-coroutines-test/test/TestRunBlockingOrderTest.kt @@ -4,9 +4,12 @@ package kotlinx.coroutines.test +import junit.framework.TestCase.assertEquals import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* import org.junit.* -import kotlin.coroutines.* +import java.util.concurrent.Executors +import kotlin.concurrent.thread class TestRunBlockingOrderTest : TestBase() { @Test @@ -68,6 +71,57 @@ class TestRunBlockingOrderTest : TestBase() { expect(2) } + @Test + fun testNewThread_inSuspendCancellableCoroutine() = runBlockingTest { + expect(1) + suspendCancellableCoroutine { cont -> + expect(2) + thread { + expect(3) + cont.resume(Unit) { Unit } + } + } + finish(4) + } + + @Test + fun testWithDelayInOtherDispatcher_passesWhenDelayIsShort() = runBlockingTest { + expect(1) + withContext(Dispatchers.IO) { + delay(1) + expect(2) + } + finish(3) + } + + @Test + fun testThrows_throws() { + val expected = IllegalStateException("expected") + val result = runCatching { + expect(1) + runBlockingTest { + expect(2) + throw expected + } + } + finish(3) + assertEquals(expected, result.exceptionOrNull()) + } + + @Test + fun testSuspendForever_fails() { + val uncompleted = CompletableDeferred() + val result = runCatching { + expect(1) + runBlockingTest(waitForOtherDispatchers = 0L) { + expect(2) + uncompleted.await() + } + } + finish(3) + assertEquals(true, result.isFailure) + } + @Test fun testAdvanceUntilIdle_inRunBlocking() = runBlockingTest { expect(1) @@ -76,4 +130,186 @@ class TestRunBlockingOrderTest : TestBase() { } finish(2) } + + @Test + fun testComplexDispatchFromOtherDispatchersOverTime_completes() { + val otherDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher() + + val max = 10 + + val numbersFromOtherDispatcherWithDelays = flow { + var current = 0 + while (current < max) { + delay(1) + emit(++current) + } + }.flowOn(otherDispatcher) + + try { + runBlockingTest { + numbersFromOtherDispatcherWithDelays.collect { value -> + expect(value) + } + expect(max + 1) + } + } finally { + otherDispatcher.close() + } + finish(max + 2) + } + + @Test + fun testComplexDispatchFromOtherDispatchersOverTime_withPasuedTestDispatcher_completes() { + val otherDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher() + + val max = 10 + + val numbersFromOtherDispatcherWithDelays = flow { for(x in 1..max) { emit(x) } } + .buffer(0) + .delayEach(1) + .flowOn(otherDispatcher) + + otherDispatcher.use { + runBlockingTest { + pauseDispatcher() + numbersFromOtherDispatcherWithDelays.collect { value -> + expect(value) + } + expect(max + 1) + } + } + finish(max + 2) + } + + @Test + fun testDispatchFromOtherDispatch_triggersInternalDispatch() { + val otherDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher() + + val numbersFromOtherDispatcherWithDelays = flow { emit(1) } + .delayEach(1) + .buffer(0) + .flowOn(otherDispatcher) + + otherDispatcher.use { + runBlockingTest { + numbersFromOtherDispatcherWithDelays.collect { value -> + expect(value) + launch { + expect(2) + } + } + expect(3) + } + } + finish(4) + } + + @Test + fun testDispatchFromOtherDispatch_triggersInternalDispatch_withDelay() { + val otherDispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher() + + val max = 10 + + val numbersFromOtherDispatcherWithDelays = flow { for(x in 1..max) { emit(x)} } + .filter { it % 2 == 1 } + .delayEach(1) + .buffer(0) + .flowOn(otherDispatcher) + + otherDispatcher.use { + runBlockingTest { + numbersFromOtherDispatcherWithDelays.collect { value -> + expect(value) + delay(1) + expect (value + 1) + } + delay(1) + expect(max + 1) + } + } + finish(max + 2) + } + + @Test + fun whenWaitConfig_timesOut_getExceptionWithMessage() { + expect(1) + val uncompleted = CompletableDeferred() + val result = runCatching { + runBlockingTest(waitForOtherDispatchers = 1L) { + withContext(Dispatchers.IO) { + finish(2) + uncompleted.await() + } + } + } + val hasDetailedError = result.exceptionOrNull()?.message?.contains("may be empty") + assertEquals(true, hasDetailedError) + uncompleted.cancel() + } + + @Test + fun whenCoroutineStartedInScope_doesntLeakOnAnotherDispatcher() { + var job: Job? = null + runBlockingTest { + expect(1) + job = launch(Dispatchers.IO) { + delay(1) + expect(3) + } + expect(2) + } + assertEquals(true, job?.isCompleted) + finish(4) + } + + @Test + fun whenDispatcherPaused_runBlocking_dispatchesToTestThread() { + val thread = Thread.currentThread() + runBlockingTest { + pauseDispatcher() + withContext(Dispatchers.IO) { + expect(1) + delay(1) + expect(2) + } + assertEquals(thread, Thread.currentThread()) + finish(3) + } + } + + @Test + fun whenDispatcherResumed_runBlocking_dispatchesImmediatelyOnIO() { + var thread: Thread? = null + runBlockingTest { + resumeDispatcher() + withContext(Dispatchers.IO) { + expect(1) + delay(1) + expect(2) + thread = Thread.currentThread() + } + assertEquals(thread, Thread.currentThread()) + finish(3) + } + } + + @Test + fun whenDispatcherRunning_doesntProgressDelays_inLaunchBody() { + var state = 0 + fun CoroutineScope.subject() = launch { + state = 1 + delay(1000) + state = 2 + } + + runBlockingTest { + subject() + + assertEquals(1, state) + + advanceTimeBy(1000) + + assertEquals(2, state) + } + } } diff --git a/kotlinx-coroutines-test/test/TestRunBlockingTest.kt b/kotlinx-coroutines-test/test/TestRunBlockingTest.kt index e0c7091505..8cfec254cd 100644 --- a/kotlinx-coroutines-test/test/TestRunBlockingTest.kt +++ b/kotlinx-coroutines-test/test/TestRunBlockingTest.kt @@ -5,6 +5,7 @@ package kotlinx.coroutines.test import kotlinx.coroutines.* +import java.lang.IllegalStateException import kotlin.coroutines.* import kotlin.test.* @@ -129,7 +130,6 @@ class TestRunBlockingTest { @Test fun whenUsingTimeout_inAsync_doesNotTriggerWhenNotDelayed() = runBlockingTest { - val testScope = this val deferred = async { withTimeout(SLOW) { delay(0) @@ -187,13 +187,13 @@ class TestRunBlockingTest { assertRunsFast { job.join() - throw job.getCancellationException().cause ?: assertFails { "expected exception" } + throw job.getCancellationException().cause ?: IllegalStateException("expected exception") } } @Test(expected = IllegalArgumentException::class) fun throwingException__inAsync_throws() = runBlockingTest { - val deferred = async { + val deferred : Deferred = async { delay(SLOW) throw IllegalArgumentException("Test") } @@ -274,7 +274,7 @@ class TestRunBlockingTest { } @Test(expected = UncompletedCoroutinesError::class) - fun whenACoroutineLeaks_errorIsThrown() = runBlockingTest { + fun whenACoroutineLeaks_errorIsThrown() = runBlockingTest(waitForOtherDispatchers = 0L) { val uncompleted = CompletableDeferred() launch { uncompleted.await() @@ -342,7 +342,7 @@ class TestRunBlockingTest { fun testWithTestContextThrowingAnAssertionError() = runBlockingTest { val expectedError = IllegalAccessError("hello") - val job = launch { + launch { throw expectedError }