Skip to content

Commit

Permalink
Properly check identity of caught AbortFlowException in Flow.first op… (
Browse files Browse the repository at this point in the history
#2057)

It fixes two problems:
    * NoSuchElementException can be thrown during cancellation sequence (see FirstJvmTest that reproduces this problem with explanation)
    * Cancellation can be accidentally suppressed and flow activity can be prolonged

Fixes #2051
  • Loading branch information
qwwdfsad committed May 27, 2020
1 parent 17248c8 commit adbbbaa
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 33 deletions.
65 changes: 32 additions & 33 deletions kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.*
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
import kotlin.jvm.*

/**
Expand Down Expand Up @@ -84,15 +82,10 @@ public suspend fun <T: Any> Flow<T>.singleOrNull(): T? {
*/
public suspend fun <T> Flow<T>.first(): T {
var result: Any? = NULL
try {
collect { value ->
result = value
throw AbortFlowException(NopCollector)
}
} catch (e: AbortFlowException) {
// Do nothing
collectUntil {
result = it
true
}

if (result === NULL) throw NoSuchElementException("Expected at least one element")
return result as T
}
Expand All @@ -103,17 +96,14 @@ public suspend fun <T> Flow<T>.first(): T {
*/
public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
var result: Any? = NULL
try {
collect { value ->
if (predicate(value)) {
result = value
throw AbortFlowException(NopCollector)
}
collectUntil {
if (predicate(it)) {
result = it
true
} else {
false
}
} catch (e: AbortFlowException) {
// Do nothing
}

if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate")
return result as T
}
Expand All @@ -124,13 +114,9 @@ public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
var result: T? = null
try {
collect { value ->
result = value
throw AbortFlowException(NopCollector)
}
} catch (e: AbortFlowException) {
// Do nothing
collectUntil {
result = it
true
}
return result
}
Expand All @@ -141,15 +127,28 @@ public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(predicate: suspend (T) -> Boolean): T? {
var result: T? = null
try {
collect { value ->
if (predicate(value)) {
result = value
throw AbortFlowException(NopCollector)
collectUntil {
if (predicate(it)) {
result = it
true
} else {
false
}
}
return result
}

internal suspend inline fun <T> Flow<T>.collectUntil(crossinline block: suspend (value: T) -> Boolean) {
val collector = object : FlowCollector<T> {
override suspend fun emit(value: T) {
if (block(value)) {
throw AbortFlowException(this)
}
}
}
try {
collect(collector)
} catch (e: AbortFlowException) {
// Do nothing
e.checkOwnership(collector)
}
return result
}
10 changes: 10 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.internal.*
import kotlin.test.*

class FirstTest : TestBase() {
Expand Down Expand Up @@ -160,4 +161,13 @@ class FirstTest : TestBase() {
assertSame(instance, flow.first { true })
assertSame(instance, flow.firstOrNull { true })
}

@Test
fun testAbortFlowException() = runTest {
val flow = flow<Int> {
throw AbortFlowException(NopCollector) // Emulate cancellation
}

assertFailsWith<CancellationException> { flow.first() }
}
}
28 changes: 28 additions & 0 deletions kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import org.junit.Test
import kotlin.test.*

class FirstJvmTest : TestBase() {

@Test
fun testTakeInterference() = runBlocking(Dispatchers.Default) {
/*
* This test tests a racy situation when outer channelFlow is being cancelled,
* inner flow starts atomically in "CANCELLING" state, sends one element and completes
* (=> cancels and drops element away), triggering NSEE in Flow.first operator
*/
val values = (0..10000).asFlow().flatMapMerge(Int.MAX_VALUE) {
channelFlow {
val value = channelFlow { send(1) }.first()
send(value)
}
}.take(1).toList()
assertEquals(listOf(1), values)
}
}

0 comments on commit adbbbaa

Please sign in to comment.