Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add name API #158

Merged
merged 2 commits into from Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 19 additions & 13 deletions src/commonMain/kotlin/app/cash/turbine/Turbine.kt
Expand Up @@ -91,14 +91,19 @@ public operator fun <T> Turbine<T>.plusAssign(value: T) { add(value) }
*
* @param timeout If non-null, overrides the current Turbine timeout for this [Turbine]. See also:
* [withTurbineTimeout].
* @param name If non-null, name is added to any exceptions thrown to help identify which [Turbine] failed.
*/
@Suppress("FunctionName") // Interface constructor pattern.
public fun <T> Turbine(timeout: Duration? = null): Turbine<T> = ChannelTurbine(timeout = timeout)
public fun <T> Turbine(
timeout: Duration? = null,
name: String? = null,
): Turbine<T> = ChannelTurbine(timeout = timeout, name = name)

internal class ChannelTurbine<T>(
channel: Channel<T> = Channel(UNLIMITED),
private val job: Job? = null,
private val timeout: Duration?,
private val name: String?,
) : Turbine<T> {
private suspend fun <T> withTurbineTimeout(block: suspend () -> T): T {
return if (timeout != null) {
Expand Down Expand Up @@ -145,13 +150,13 @@ internal class ChannelTurbine<T>(
job?.cancel()
}

override fun takeEvent(): Event<T> = channel.takeEvent()
override fun takeEvent(): Event<T> = channel.takeEvent(name = name)

override fun takeItem(): T = channel.takeItem()
override fun takeItem(): T = channel.takeItem(name = name)

override fun takeComplete() = channel.takeComplete()
override fun takeComplete() = channel.takeComplete(name = name)

override fun takeError(): Throwable = channel.takeError()
override fun takeError(): Throwable = channel.takeError(name = name)

private var ignoreTerminalEvents = false
private var ignoreRemainingEvents = false
Expand All @@ -176,20 +181,20 @@ internal class ChannelTurbine<T>(
}

override fun expectNoEvents() {
channel.expectNoEvents()
channel.expectNoEvents(name = name)
}

override fun expectMostRecentItem(): T = channel.expectMostRecentItem()
override fun expectMostRecentItem(): T = channel.expectMostRecentItem(name = name)

override suspend fun awaitEvent(): Event<T> = withTurbineTimeout { channel.awaitEvent() }
override suspend fun awaitEvent(): Event<T> = withTurbineTimeout { channel.awaitEvent(name = name) }

override suspend fun awaitItem(): T = withTurbineTimeout { channel.awaitItem() }
override suspend fun awaitItem(): T = withTurbineTimeout { channel.awaitItem(name = name) }

override suspend fun skipItems(count: Int) = withTurbineTimeout { channel.skipItems(count) }
override suspend fun skipItems(count: Int) = withTurbineTimeout { channel.skipItems(count, name) }

override suspend fun awaitComplete() = withTurbineTimeout { channel.awaitComplete() }
override suspend fun awaitComplete() = withTurbineTimeout { channel.awaitComplete(name = name) }

override suspend fun awaitError(): Throwable = withTurbineTimeout { channel.awaitError() }
override suspend fun awaitError(): Throwable = withTurbineTimeout { channel.awaitError(name = name) }

override fun ensureAllEventsConsumed() {
if (ignoreRemainingEvents) return
Expand All @@ -209,7 +214,8 @@ internal class ChannelTurbine<T>(
if (unconsumed.isNotEmpty()) {
throw TurbineAssertionError(
buildString {
append("Unconsumed events found:")
append("Unconsumed events found".qualifiedBy(name))
append(":")
for (event in unconsumed) {
append("\n - $event")
}
Expand Down
62 changes: 35 additions & 27 deletions src/commonMain/kotlin/app/cash/turbine/channel.kt
Expand Up @@ -42,7 +42,7 @@ import kotlinx.coroutines.withTimeout
*
* @throws AssertionError if no item was emitted.
*/
public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {
public fun <T> ReceiveChannel<T>.expectMostRecentItem(name: String? = null): T {
var previous: ChannelResult<T>? = null
while (true) {
val current = tryReceive()
Expand All @@ -55,7 +55,7 @@ public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {

if (previous?.isSuccess == true) return previous.getOrThrow()

throw AssertionError("No item was found")
throw AssertionError("No item was found".qualifiedBy(name))
}

/**
Expand All @@ -66,9 +66,9 @@ public fun <T> ReceiveChannel<T>.expectMostRecentItem(): T {
*
* @throws AssertionError if unconsumed events are found.
*/
public fun <T> ReceiveChannel<T>.expectNoEvents() {
public fun <T> ReceiveChannel<T>.expectNoEvents(name: String? = null) {
val result = tryReceive()
if (!result.isFailure) result.unexpectedResult("no events")
if (!result.isFailure) result.unexpectedResult(name, "no events")
}

/**
Expand All @@ -77,17 +77,17 @@ public fun <T> ReceiveChannel<T>.expectNoEvents() {
*
* This function will always return a terminal event on a closed [ReceiveChannel].
*/
public suspend fun <T> ReceiveChannel<T>.awaitEvent(): Event<T> {
public suspend fun <T> ReceiveChannel<T>.awaitEvent(name: String? = null): Event<T> {
val timeout = contextTimeout()
return try {
withAppropriateTimeout(timeout) {
val item = receive()
Event.Item(item)
}
} catch (e: TimeoutCancellationException) {
throw AssertionError("No value produced in $timeout")
throw AssertionError("No ${"value produced".qualifiedBy(name)} in $timeout")
} catch (e: TurbineTimeoutCancellationException) {
throw AssertionError("No value produced in $timeout")
throw AssertionError("No ${"value produced".qualifiedBy(name)} in $timeout")
} catch (e: CancellationException) {
throw e
} catch (e: ClosedReceiveChannelException) {
Expand Down Expand Up @@ -139,10 +139,10 @@ internal class TurbineTimeoutCancellationException internal constructor(
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeEvent(): Event<T> {
public fun <T> ReceiveChannel<T>.takeEvent(name: String? = null): Event<T> {
assertCallingContextIsNotSuspended()
return takeEventUnsafe()
?: unexpectedEvent(null, "an event")
?: unexpectedEvent(name, null, "an event")
}

internal fun <T> ReceiveChannel<T>.takeEventUnsafe(): Event<T>? {
Expand All @@ -155,9 +155,9 @@ internal fun <T> ReceiveChannel<T>.takeEventUnsafe(): Event<T>? {
*
* @throws AssertionError if the next event was completion or an error, or no event.
*/
public fun <T> ReceiveChannel<T>.takeItem(): T {
public fun <T> ReceiveChannel<T>.takeItem(name: String? = null): T {
val event = takeEvent()
return (event as? Event.Item)?.value ?: unexpectedEvent(event, "item")
return (event as? Event.Item)?.value ?: unexpectedEvent(name, event, "item")
}

/**
Expand All @@ -166,9 +166,9 @@ public fun <T> ReceiveChannel<T>.takeItem(): T {
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeComplete() {
public fun <T> ReceiveChannel<T>.takeComplete(name: String? = null) {
val event = takeEvent()
if (event !is Event.Complete) unexpectedEvent(event, "complete")
if (event !is Event.Complete) unexpectedEvent(name, event, "complete")
}

/**
Expand All @@ -177,9 +177,9 @@ public fun <T> ReceiveChannel<T>.takeComplete() {
*
* @throws AssertionError if the next event was completion or an error.
*/
public fun <T> ReceiveChannel<T>.takeError(): Throwable {
public fun <T> ReceiveChannel<T>.takeError(name: String? = null): Throwable {
val event = takeEvent()
return (event as? Event.Error)?.throwable ?: unexpectedEvent(event, "error")
return (event as? Event.Error)?.throwable ?: unexpectedEvent(name, event, "error")
}

/**
Expand All @@ -188,10 +188,10 @@ public fun <T> ReceiveChannel<T>.takeError(): Throwable {
*
* @throws AssertionError if the next event was completion or an error.
*/
public suspend fun <T> ReceiveChannel<T>.awaitItem(): T =
when (val result = awaitEvent()) {
public suspend fun <T> ReceiveChannel<T>.awaitItem(name: String? = null): T =
when (val result = awaitEvent(name = name)) {
is Event.Item -> result.value
else -> unexpectedEvent(result, "item")
else -> unexpectedEvent(name, result, "item")
}

/**
Expand All @@ -200,12 +200,12 @@ public suspend fun <T> ReceiveChannel<T>.awaitItem(): T =
*
* @throws AssertionError if one of the events was completion or an error.
*/
public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int) {
public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int, name: String? = null) {
repeat(count) { index ->
when (val event = awaitEvent()) {
Event.Complete, is Event.Error -> {
val cause = (event as? Event.Error)?.throwable
throw TurbineAssertionError("Expected $count items but got $index items and $event", cause)
throw TurbineAssertionError("Expected $count ${"items".qualifiedBy(name)} but got $index items and $event", cause)
}
is Event.Item<T> -> {
// Success
Expand All @@ -220,10 +220,10 @@ public suspend fun <T> ReceiveChannel<T>.skipItems(count: Int) {
*
* @throws AssertionError if the next event was an item or an error.
*/
public suspend fun <T> ReceiveChannel<T>.awaitComplete() {
public suspend fun <T> ReceiveChannel<T>.awaitComplete(name: String? = null) {
val event = awaitEvent()
if (event != Event.Complete) {
unexpectedEvent(event, "complete")
unexpectedEvent(name, event, "complete")
}
}

Expand All @@ -233,10 +233,10 @@ public suspend fun <T> ReceiveChannel<T>.awaitComplete() {
*
* @throws AssertionError if the next event was an item or completion.
*/
public suspend fun <T> ReceiveChannel<T>.awaitError(): Throwable {
public suspend fun <T> ReceiveChannel<T>.awaitError(name: String? = null): Throwable {
val event = awaitEvent()
return (event as? Event.Error)?.throwable
?: unexpectedEvent(event, "error")
?: unexpectedEvent(name, event, "error")
}

internal fun <T> ChannelResult<T>.toEvent(): Event<T>? {
Expand All @@ -249,10 +249,18 @@ internal fun <T> ChannelResult<T>.toEvent(): Event<T>? {
}
}

private fun <T> ChannelResult<T>.unexpectedResult(expected: String): Nothing = unexpectedEvent(toEvent(), expected)
private fun <T> ChannelResult<T>.unexpectedResult(name: String?, expected: String): Nothing =
unexpectedEvent(name, toEvent(), expected)

private fun unexpectedEvent(event: Event<*>?, expected: String): Nothing {
private fun unexpectedEvent(name: String?, event: Event<*>?, expected: String): Nothing {
val cause = (event as? Event.Error)?.throwable
val eventAsString = event?.toString() ?: "no items"
throw TurbineAssertionError("Expected $expected but found $eventAsString", cause)
throw TurbineAssertionError("Expected ${expected.qualifiedBy(name)} but found $eventAsString", cause)
}

internal fun String.qualifiedBy(name: String?) =
if (name == null) {
this
} else {
"$this for $name"
}
15 changes: 10 additions & 5 deletions src/commonMain/kotlin/app/cash/turbine/flow.kt
Expand Up @@ -48,10 +48,11 @@ import kotlinx.coroutines.test.UnconfinedTestDispatcher
*/
public suspend fun <T> Flow<T>.test(
timeout: Duration? = null,
name: String? = null,
validate: suspend ReceiveTurbine<T>.() -> Unit,
) {
coroutineScope {
collectTurbineIn(this, null).apply {
collectTurbineIn(this, null, name).apply {
if (timeout != null) {
withTurbineTimeout(timeout) {
validate()
Expand Down Expand Up @@ -83,13 +84,17 @@ public suspend fun <T> Flow<T>.test(
* @param timeout If non-null, overrides the current Turbine timeout for this [Turbine]. See also:
* [withTurbineTimeout].
*/
public fun <T> Flow<T>.testIn(scope: CoroutineScope, timeout: Duration? = null): ReceiveTurbine<T> {
public fun <T> Flow<T>.testIn(
scope: CoroutineScope,
timeout: Duration? = null,
name: String? = null,
): ReceiveTurbine<T> {
if (timeout != null) {
// Eager check to throw early rather than in a subsequent 'await' call.
checkTimeout(timeout)
}

val turbine = collectTurbineIn(scope, timeout)
val turbine = collectTurbineIn(scope, timeout, name)

scope.coroutineContext.job.invokeOnCompletion { exception ->
if (debug) println("Scope ending ${exception ?: ""}")
Expand All @@ -104,7 +109,7 @@ public fun <T> Flow<T>.testIn(scope: CoroutineScope, timeout: Duration? = null):
}

@OptIn(ExperimentalCoroutinesApi::class) // New kotlinx.coroutines test APIs are not stable 😬
private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duration?): Turbine<T> {
private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duration?, name: String?): Turbine<T> {
lateinit var channel: Channel<T>

// Use test-specific unconfined if test scheduler is in use to inherit its virtual time.
Expand All @@ -116,7 +121,7 @@ private fun <T> Flow<T>.collectTurbineIn(scope: CoroutineScope, timeout: Duratio
channel = collectIntoChannel(this)
}

return ChannelTurbine(channel, job, timeout)
return ChannelTurbine(channel, job, timeout, name)
}

internal fun <T> Flow<T>.collectIntoChannel(scope: CoroutineScope): Channel<T> {
Expand Down
63 changes: 63 additions & 0 deletions src/commonTest/kotlin/app/cash/turbine/ChannelTest.kt
Expand Up @@ -108,6 +108,14 @@ class ChannelTest {
assertEquals(3, channel.awaitItem())
}

@Test fun skipItemsThrowsOnComplete() = runTest {
val channel = flowOf(1, 2).collectIntoChannel(this)
val message = assertFailsWith<AssertionError> {
channel.skipItems(3)
}.message
assertEquals("Expected 3 items but got 2 items and Complete", message)
}

@Test fun expectErrorOnCompletionBeforeAllItemsWereSkipped() = runTest {
val channel = flowOf(1).collectIntoChannel(this)
assertFailsWith<AssertionError> {
Expand Down Expand Up @@ -286,6 +294,61 @@ class ChannelTest {
assertSame(error, actual.cause)
}

@Test
fun expectMostRecentItemButNoItemWasFoundThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
val channel = emptyFlow<Any>().collectIntoChannel(this)
channel.expectMostRecentItem(name = "empty flow")
}
assertEquals("No item was found for empty flow", actual.message)
}

@Test fun awaitItemButWasCloseThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
emptyFlow<Unit>().collectIntoChannel(this).awaitItem(name = "closed flow")
}
assertEquals("Expected item for closed flow but found Complete", actual.message)
}

@Test fun awaitCompleteButWasItemThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
flowOf("item!").collectIntoChannel(this)
.awaitComplete(name = "item flow")
}
assertEquals("Expected complete for item flow but found Item(item!)", actual.message)
}

@Test fun awaitErrorButWasItemThrowsWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
flowOf("item!").collectIntoChannel(this).awaitError(name = "item flow")
}
assertEquals("Expected error for item flow but found Item(item!)", actual.message)
}

@Test fun awaitHonorsCoroutineContextTimeoutTimeoutWithName() = runTest {
val actual = assertFailsWith<AssertionError> {
withTurbineTimeout(10.milliseconds) {
neverFlow().collectIntoChannel(this).awaitItem(name = "never flow")
}
}
assertEquals("No value produced for never flow in 10ms", actual.message)
}

@Test fun takeItemButWasCloseThrowsWithName() = withTestScope {
val actual = assertFailsWith<AssertionError> {
emptyFlow<Unit>().collectIntoChannel(this).takeItem(name = "empty flow")
}
assertEquals("Expected item for empty flow but found Complete", actual.message)
}

@Test fun skipItemsThrowsOnCompleteWithName() = runTest {
val channel = flowOf(1, 2).collectIntoChannel(this)
val message = assertFailsWith<AssertionError> {
channel.skipItems(3, name = "two item channel")
}.message
assertEquals("Expected 3 items for two item channel but got 2 items and Complete", message)
}

/**
* Used to run test code with a [TestScope], but still outside a suspending context.
*/
Expand Down