Skip to content

Commit

Permalink
Flow.transformWhile operator (#2066)
Browse files Browse the repository at this point in the history
Also, most flow-truncating operators are refactored via a common internal collectWhile operator that properly uses AbortFlowException and checks for its ownership, so that we don't have to look for bugs in interactions between all those operators (and zip, too, which is also flow-truncating). But `take` operator still users a custom highly-tuned implementation.

Fixes #2065

Co-authored-by: EdwarDDay <4127904+EdwarDDay@users.noreply.github.com>
Co-authored-by: Louis CAD <louis.cognault@gmail.com>
  • Loading branch information
3 people committed Jul 16, 2020
1 parent 5183b62 commit b42f986
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 35 deletions.
68 changes: 68 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/flow/TakeWhileBenchmark.kt
@@ -0,0 +1,68 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.internal.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.TimeUnit

@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class TakeWhileBenchmark {
@Param("1", "10", "100", "1000")
private var size: Int = 0

private suspend inline fun Flow<Long>.consume() =
filter { it % 2L != 0L }
.map { it * it }.count()

@Benchmark
fun baseline() = runBlocking<Int> {
(0L until size).asFlow().consume()
}

@Benchmark
fun takeWhileDirect() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().takeWhileDirect { it < size }.consume()
}

@Benchmark
fun takeWhileViaCollectWhile() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().takeWhileViaCollectWhile { it < size }.consume()
}

// Direct implementation by checking predicate and throwing AbortFlowException
private fun <T> Flow<T>.takeWhileDirect(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
try {
collect { value ->
if (predicate(value)) emit(value)
else throw AbortFlowException(this)
}
} catch (e: AbortFlowException) {
e.checkOwnership(owner = this)
}
}

// Essentially the same code, but reusing the logic via collectWhile function
private fun <T> Flow<T>.takeWhileViaCollectWhile(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
// This return is needed to work around a bug in JS BE: KT-39227
return@unsafeFlow collectWhile { value ->
if (predicate(value)) {
emit(value)
true
} else {
false
}
}
}
}
1 change: 1 addition & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Expand Up @@ -995,6 +995,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
public static synthetic fun toSet$default (Lkotlinx/coroutines/flow/Flow;Ljava/util/Set;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun transform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun transformLatest (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun transformWhile (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun unsafeTransform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun withIndex (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
public static final fun zip (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
Expand Down
7 changes: 4 additions & 3 deletions kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt
Expand Up @@ -19,10 +19,11 @@ import kotlin.jvm.*
/**
* Applies [transform] function to each value of the given flow.
*
* The receiver of the [transform] is [FlowCollector] and thus `transform` is a
* generic function that may transform emitted element, skip it or emit it multiple times.
* The receiver of the `transform` is [FlowCollector] and thus `transform` is a
* flexible function that may transform emitted element, skip it or emit it multiple times.
*
* This operator can be used as a building block for other operators, for example:
* This operator generalizes [filter] and [map] operators and
* can be used as a building block for other operators, for example:
*
* ```
* fun Flow<Int>.skipOddAndDuplicateEven(): Flow<Int> = transform { value ->
Expand Down
69 changes: 64 additions & 5 deletions kotlinx-coroutines-core/common/src/flow/operators/Limit.kt
Expand Up @@ -7,8 +7,10 @@

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.*
import kotlin.jvm.*
import kotlinx.coroutines.flow.flow as safeFlow
import kotlinx.coroutines.flow.internal.unsafeFlow as flow

/**
Expand Down Expand Up @@ -51,6 +53,10 @@ public fun <T> Flow<T>.take(count: Int): Flow<T> {
var consumed = 0
try {
collect { value ->
// Note: this for take is not written via collectWhile on purpose.
// It checks condition first and then makes a tail-call to either emit or emitAbort.
// This way normal execution does not require a state machine, only a termination (emitAbort).
// See "TakeBenchmark" for comparision of different approaches.
if (++consumed < count) {
return@collect emit(value)
} else {
Expand All @@ -70,14 +76,67 @@ private suspend fun <T> FlowCollector<T>.emitAbort(value: T) {

/**
* Returns a flow that contains first elements satisfying the given [predicate].
*
* Note, that the resulting flow does not contain the element on which the [predicate] returned `false`.
* See [transformWhile] for a more flexible operator.
*/
public fun <T> Flow<T>.takeWhile(predicate: suspend (T) -> Boolean): Flow<T> = flow {
try {
collect { value ->
if (predicate(value)) emit(value)
else throw AbortFlowException(this)
// This return is needed to work around a bug in JS BE: KT-39227
return@flow collectWhile { value ->
if (predicate(value)) {
emit(value)
true
} else {
false
}
}
}

/**
* Applies [transform] function to each value of the given flow while this
* function returns `true`.
*
* The receiver of the `transformWhile` is [FlowCollector] and thus `transformWhile` is a
* flexible function that may transform emitted element, skip it or emit it multiple times.
*
* This operator generalizes [takeWhile] and can be used as a building block for other operators.
* For example, a flow of download progress messages can be completed when the
* download is done but emit this last message (unlike `takeWhile`):
*
* ```
* fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
* transformWhile { progress ->
* emit(progress) // always emit progress
* !progress.isDone() // continue while download is not done
* }
* }
* ```
*/
@ExperimentalCoroutinesApi
public fun <T, R> Flow<T>.transformWhile(
@BuilderInference transform: suspend FlowCollector<R>.(value: T) -> Boolean
): Flow<R> =
safeFlow { // Note: safe flow is used here, because collector is exposed to transform on each operation
// This return is needed to work around a bug in JS BE: KT-39227
return@safeFlow collectWhile { value ->
transform(value)
}
}

// Internal building block for non-tailcalling flow-truncating operators
internal suspend inline fun <T> Flow<T>.collectWhile(crossinline predicate: suspend (value: T) -> Boolean) {
val collector = object : FlowCollector<T> {
override suspend fun emit(value: T) {
// Note: we are checking predicate first, then throw. If the predicate does suspend (calls emit, for example)
// the the resulting code is never tail-suspending and produces a state-machine
if (!predicate(value)) {
throw AbortFlowException(this)
}
}
}
try {
collect(collector)
} catch (e: AbortFlowException) {
e.checkOwnership(owner = this)
e.checkOwnership(collector)
}
}
35 changes: 10 additions & 25 deletions kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
Expand Up @@ -82,9 +82,9 @@ public suspend fun <T: Any> Flow<T>.singleOrNull(): T? {
*/
public suspend fun <T> Flow<T>.first(): T {
var result: Any? = NULL
collectUntil {
collectWhile {
result = it
true
false
}
if (result === NULL) throw NoSuchElementException("Expected at least one element")
return result as T
Expand All @@ -96,12 +96,12 @@ public suspend fun <T> Flow<T>.first(): T {
*/
public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
var result: Any? = NULL
collectUntil {
collectWhile {
if (predicate(it)) {
result = it
true
} else {
false
} else {
true
}
}
if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate")
Expand All @@ -114,9 +114,9 @@ public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
var result: T? = null
collectUntil {
collectWhile {
result = it
true
false
}
return result
}
Expand All @@ -127,28 +127,13 @@ public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(predicate: suspend (T) -> Boolean): T? {
var result: T? = null
collectUntil {
collectWhile {
if (predicate(it)) {
result = it
true
} else {
false
} else {
true
}
}
return result
}

internal suspend inline fun <T> Flow<T>.collectUntil(crossinline block: suspend (value: T) -> Boolean) {
val collector = object : FlowCollector<T> {
override suspend fun emit(value: T) {
if (block(value)) {
throw AbortFlowException(this)
}
}
}
try {
collect(collector)
} catch (e: AbortFlowException) {
e.checkOwnership(collector)
}
}
34 changes: 32 additions & 2 deletions kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt
Expand Up @@ -192,7 +192,7 @@ class FlowInvariantsTest : TestBase() {
}

@Test
fun testEmptyCoroutineContext() = runTest {
fun testEmptyCoroutineContextMap() = runTest {
emptyContextTest {
map {
expect(it)
Expand All @@ -212,7 +212,18 @@ class FlowInvariantsTest : TestBase() {
}

@Test
fun testEmptyCoroutineContextViolation() = runTest {
fun testEmptyCoroutineContextTransformWhile() = runTest {
emptyContextTest {
transformWhile {
expect(it)
emit(it + 1)
true
}
}
}

@Test
fun testEmptyCoroutineContextViolationTransform() = runTest {
try {
emptyContextTest {
transform {
Expand All @@ -229,6 +240,25 @@ class FlowInvariantsTest : TestBase() {
}
}

@Test
fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
try {
emptyContextTest {
transformWhile {
expect(it)
withContext(Dispatchers.Unconfined) {
emit(it + 1)
}
true
}
}
expectUnreached()
} catch (e: IllegalStateException) {
assertTrue(e.message!!.contains("Flow invariant is violated"))
finish(2)
}
}

private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
suspend fun collector(): Int {
var result: Int = -1
Expand Down
@@ -0,0 +1,70 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlin.test.*

class TransformWhileTest : TestBase() {
@Test
fun testSimple() = runTest {
val flow = (0..10).asFlow()
val expected = listOf("A", "B", "C", "D")
val actual = flow.transformWhile { value ->
when(value) {
0 -> { emit("A"); true }
1 -> true
2 -> { emit("B"); emit("C"); true }
3 -> { emit("D"); false }
else -> { expectUnreached(); false }
}
}.toList()
assertEquals(expected, actual)
}

@Test
fun testCancelUpstream() = runTest {
var cancelled = false
val flow = flow {
coroutineScope {
launch(start = CoroutineStart.ATOMIC) {
hang { cancelled = true }
}
emit(1)
emit(2)
emit(3)
}
}
val transformed = flow.transformWhile {
emit(it)
it < 2
}
assertEquals(listOf(1, 2), transformed.toList())
assertTrue(cancelled)
}

@Test
fun testExample() = runTest {
val source = listOf(
DownloadProgress(0),
DownloadProgress(50),
DownloadProgress(100),
DownloadProgress(147)
)
val expected = source.subList(0, 3)
val actual = source.asFlow().completeWhenDone().toList()
assertEquals(expected, actual)
}

private fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
transformWhile { progress ->
emit(progress) // always emit progress
!progress.isDone() // continue while download is not done
}

private data class DownloadProgress(val percent: Int) {
fun isDone() = percent >= 100
}
}

0 comments on commit b42f986

Please sign in to comment.