Skip to content

Commit

Permalink
Restore timeouts using virtual time
Browse files Browse the repository at this point in the history
  • Loading branch information
jingibus committed Jul 27, 2022
1 parent 5ea5ba3 commit 01dbefa
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 50 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ kotlin {
commonMain {
dependencies {
api "org.jetbrains.kotlinx:kotlinx-coroutines-core:${versions.coroutines}"
implementation 'org.jetbrains.kotlin:kotlin-test'
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-test:${versions.coroutines}"
}
}
commonTest {
Expand Down
23 changes: 17 additions & 6 deletions src/commonMain/kotlin/app/cash/turbine/Turbine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
*/
package app.cash.turbine

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.channels.ChannelResult
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withContext

internal const val debug = false

Expand Down Expand Up @@ -88,12 +91,20 @@ public operator fun <T> Turbine<T>.plusAssign(value: T) { add(value) }
/**
* Construct a standalone [Turbine].
*/
public fun <T> Turbine(): Turbine<T> = TurbineImpl()
public fun <T> Turbine(timeoutMs: Long? = null): Turbine<T> = TurbineImpl(timeoutMs = timeoutMs)

internal class TurbineImpl<T>(
channel: Channel<T> = Channel(UNLIMITED),
private val job: Job? = null,
private val timeoutMs: Long? = null,
) : Turbine<T> {
private suspend fun <T> withTurbineTimeout(block: suspend CoroutineScope.() -> T): T {
return if (timeoutMs != null) {
withTurbineTimeout(timeoutMs, block)
} else coroutineScope {
block()
}
}

private val channel = object : Channel<T> by channel {
override fun tryReceive(): ChannelResult<T> {
Expand Down Expand Up @@ -169,15 +180,15 @@ internal class TurbineImpl<T>(

override fun expectMostRecentItem(): T = channel.expectMostRecentItem()

override suspend fun awaitEvent(): Event<T> = channel.awaitEvent()
override suspend fun awaitEvent(): Event<T> = withTurbineTimeout { channel.awaitEvent() }

override suspend fun awaitItem(): T = channel.awaitItem()
override suspend fun awaitItem(): T = withTurbineTimeout { channel.awaitItem() }

override suspend fun skipItems(count: Int) = channel.skipItems(count)
override suspend fun skipItems(count: Int) = withTurbineTimeout { channel.skipItems(count) }

override suspend fun awaitComplete() = channel.awaitComplete()
override suspend fun awaitComplete() = withTurbineTimeout { channel.awaitComplete() }

override suspend fun awaitError(): Throwable = channel.awaitError()
override suspend fun awaitError(): Throwable = withTurbineTimeout { channel.awaitError() }

override fun ensureAllEventsConsumed() {
if (ignoreRemainingEvents) return
Expand Down
52 changes: 46 additions & 6 deletions src/commonMain/kotlin/app/cash/turbine/channel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
*/
package app.cash.turbine

import kotlin.coroutines.coroutineContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.channels.ChannelResult
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.test.TestCoroutineScheduler

/**
* Returns the most recent item that has already been received.
Expand Down Expand Up @@ -64,16 +70,44 @@ public fun <T> ReceiveChannel<T>.expectNoEvents() {
*
* This function will always return a terminal event on a closed [ReceiveChannel].
*/
public suspend fun <T> ReceiveChannel<T>.awaitEvent(): Event<T> =
try {
Event.Item(receive())
@OptIn(ExperimentalCoroutinesApi::class)
public suspend fun <T> ReceiveChannel<T>.awaitEvent(): Event<T> {
val timeoutMs = contextTimeout()
val testScheduler = coroutineContext[TestCoroutineScheduler]
return try {
withTimeout(timeoutMs) {
val item = if (testScheduler == null) {
// With no test scheduler, let receive() expire the timeout. This will use wallclock time.
receive()
} else {
// *With* a test scheduler, we must poll and nudge the clock
// until some kind of result is produced.
val value: T
while (true) {
val result = tryReceive()
if (result.isFailure && !result.isClosed) {
delay(timeoutMs / 10)
} else if (result.isFailure && result.isClosed) {
throw (result.exceptionOrNull() ?: ClosedReceiveChannelException(null))
} else {
value = result.getOrThrow()
break
}
}
value
}
Event.Item(item)
}
} catch (e: TimeoutCancellationException) {
throw AssertionError("No value produced in ${timeoutMs}ms")
} catch (e: CancellationException) {
throw e
} catch (e: ClosedReceiveChannelException) {
Event.Complete
} catch (e: Exception) {
Event.Error(e)
}
}

/**
* Assert that the next event received was non-null and return it.
Expand Down Expand Up @@ -130,11 +164,15 @@ public fun <T> ReceiveChannel<T>.takeError(): Throwable {
*
* @throws AssertionError if the next event was completion or an error.
*/
public suspend fun <T> ReceiveChannel<T>.awaitItem(): T =
public suspend fun <T> ReceiveChannel<T>.awaitItem(): T = try {
when (val result = awaitEvent()) {
is Event.Item -> result.value
else -> unexpectedEvent(result, "item")
}
} catch (e: Exception) {
println("Caught it! $e")
throw e
}

/**
* Assert that [count] item events were received and ignore them.
Expand All @@ -149,7 +187,8 @@ public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int) {
val cause = (event as? Event.Error)?.throwable
throw TurbineAssertionError("Expected $count items but got $index items and $event", cause)
}
is Event.Item<T> -> { /* Success */ }
is Event.Item<T> -> { /* Success */
}
}
}
}
Expand All @@ -173,7 +212,7 @@ public suspend fun <T> ReceiveChannel<T>.awaitComplete() {
*
* @throws AssertionError if the next event was an item or completion.
*/
public suspend fun <T> ReceiveChannel<T>.awaitError(): Throwable {
public suspend fun <T> ReceiveChannel<T>.awaitError(): Throwable {
val event = awaitEvent()
return (event as? Event.Error)?.throwable
?: unexpectedEvent(event, "error")
Expand All @@ -187,6 +226,7 @@ internal fun <T> ChannelResult<T>.toEvent(): Event<T>? {
else if (isClosed) Event.Complete
else null
}

private fun <T> ChannelResult<T>.unexpectedResult(expected: String): Nothing = unexpectedEvent(toEvent(), expected)

private fun unexpectedEvent(event: Event<*>?, expected: String): Nothing {
Expand Down
27 changes: 27 additions & 0 deletions src/commonMain/kotlin/app/cash/turbine/coroutines.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@
*/
package app.cash.turbine

import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext

public const val DEFAULT_TIMEOUT_MS: Long = 1000

/**
* Sets the timeout for all [Turbine] instances within this context.
*/
public suspend fun <T> withTurbineTimeout(timeoutMs: Long, block: suspend CoroutineScope.() -> T): T {
return withContext(TurbineTimeoutElement(timeoutMs)) {
block()
}
}

/**
* Invoke this method to throw an error when your method is not being called by a suspend fun.
*
Expand Down Expand Up @@ -44,3 +60,14 @@ internal fun assertCallingContextIsNotSuspended() {
error("Calling context is suspending; use a suspending method instead")
}
}

internal class TurbineTimeoutElement(
val timeout: Long,
) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<TurbineTimeoutElement>

override val key: CoroutineContext.Key<*> = Key
}
internal suspend fun contextTimeout(): Long {
return currentCoroutineContext()[TurbineTimeoutElement.Key]?.timeout ?: DEFAULT_TIMEOUT_MS
}
55 changes: 17 additions & 38 deletions src/commonMain/kotlin/app/cash/turbine/flow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,11 @@ import kotlinx.coroutines.launch
* }
* ```
*/
@Deprecated("Timeout parameter removed. Use runTest which has a timeout or wrap in withTimeout.",
ReplaceWith("this.test(validate)"),
DeprecationLevel.ERROR,
)
@Suppress("UNUSED_PARAMETER")
public suspend fun <T> Flow<T>.test(
timeoutMs: Long,
timeout: Duration,
validate: suspend ReceiveTurbine<T>.() -> Unit,
) {
test(validate)
}

/**
* Terminal flow operator that collects events from given flow and allows the [validate] lambda to
* consume and assert properties on them in order. If any exception occurs during validation the
* exception is rethrown from this method.
*
* ```kotlin
* flowOf("one", "two").test {
* assertEquals("one", expectItem())
* assertEquals("two", expectItem())
* expectComplete()
* }
* ```
*/
@Deprecated("Timeout parameter removed. Use runTest which has a timeout or wrap in withTimeout.",
ReplaceWith("this.test(validate)"),
DeprecationLevel.ERROR,
)
@Suppress("UNUSED_PARAMETER")
public suspend fun <T> Flow<T>.test(
timeout: Duration = 1.seconds,
validate: suspend ReceiveTurbine<T>.() -> Unit,
) {
test(validate)
test(timeoutMs = timeout.inWholeMilliseconds, validate)
}

/**
Expand All @@ -92,10 +62,11 @@ public suspend fun <T> Flow<T>.test(
* ```
*/
public suspend fun <T> Flow<T>.test(
timeoutMs: Long? = null,
validate: suspend ReceiveTurbine<T>.() -> Unit,
) {
coroutineScope {
collectTurbineIn(this).apply {
collectTurbineIn(this, timeoutMs).apply {
validate()
cancel()
ensureAllEventsConsumed()
Expand All @@ -118,8 +89,8 @@ public suspend fun <T> Flow<T>.test(
* Unlike [test] which automatically cancels the flow at the end of the lambda, the returned
* [ReceiveTurbine] must either consume a terminal event (complete or error) or be explicitly canceled.
*/
public fun <T> Flow<T>.testIn(scope: CoroutineScope): ReceiveTurbine<T> {
val turbine = collectTurbineIn(scope)
public fun <T> Flow<T>.testIn(scope: CoroutineScope, timeoutMs: Long? = null): ReceiveTurbine<T> {
val turbine = collectTurbineIn(scope, timeoutMs)

scope.coroutineContext.job.invokeOnCompletion { exception ->
if (debug) println("Scope ending ${exception ?: ""}")
Expand All @@ -133,14 +104,22 @@ public fun <T> Flow<T>.testIn(scope: CoroutineScope): ReceiveTurbine<T> {
return turbine
}

private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope): Turbine<T> {
private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeoutMs: Long?): Turbine<T> {
lateinit var outputBox: Channel<T>

val job = scope.launch(start = UNDISPATCHED) {
outputBox = collectIntoChannel(this)

val block: suspend CoroutineScope.() -> Unit = {
outputBox = collectIntoChannel(this)
}
if (timeoutMs != null) {
withTurbineTimeout(timeoutMs, block)
} else {
block()
}
}

return TurbineImpl(outputBox, job)
return TurbineImpl(outputBox, job, timeoutMs = null)
}

internal fun <T> Flow<T>.collectIntoChannel(scope: CoroutineScope): Channel<T> {
Expand Down
37 changes: 37 additions & 0 deletions src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flow
Expand All @@ -42,6 +47,7 @@ import kotlinx.coroutines.test.TestCoroutineScheduler
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withTimeoutOrNull

class ChannelTest {
@Test
Expand Down Expand Up @@ -227,6 +233,37 @@ class ChannelTest {
assertEquals("Expected error but found Complete", actual.message)
}

@Test fun failsOnDefaultTimeout() = runTest {
assertFailsWith<AssertionError> {
coroutineScope {
neverFlow().collectIntoChannel(this).awaitItem()
}
}
}

@Test fun awaitHonorsCoroutineContextTimeoutNoTimeout() = runTest {
withTurbineTimeout(5000) {
val job = launch {
neverFlow().collectIntoChannel(this).awaitItem()
}

delay(3000)
job.cancel()
}
}

@Test fun awaitHonorsCoroutineContextTimeoutTimeout() = runTest {
assertFailsWith<AssertionError> {
withTurbineTimeout(5000) {
launch {
neverFlow().collectIntoChannel(this).awaitItem()
}

delay(5000)
}
}
}

@Test fun takeItem() = withTestScope {
val item = Any()
val channel = flowOf(item).collectIntoChannel(this)
Expand Down

0 comments on commit 01dbefa

Please sign in to comment.