From 92748beb48392973a791b20074fd3e8ae2a02332 Mon Sep 17 00:00:00 2001 From: Louis Wasserman Date: Thu, 20 May 2021 15:08:53 -0700 Subject: [PATCH] Add update, updateAndGet, and getAndUpdate extension functions to MutableStateFlow (#2720). --- .../common/src/flow/StateFlow.kt | 57 +++++++++++++++++-- .../common/test/flow/sharing/StateFlowTest.kt | 55 ++++++++++++++++++ 2 files changed, 107 insertions(+), 5 deletions(-) diff --git a/kotlinx-coroutines-core/common/src/flow/StateFlow.kt b/kotlinx-coroutines-core/common/src/flow/StateFlow.kt index da06ec73b9..726ce71585 100644 --- a/kotlinx-coroutines-core/common/src/flow/StateFlow.kt +++ b/kotlinx-coroutines-core/common/src/flow/StateFlow.kt @@ -37,7 +37,7 @@ import kotlin.native.concurrent.* * val counter = _counter.asStateFlow() // publicly exposed as read-only state flow * * fun inc() { - * _counter.value++ + * _counter.update { count -> count + 1 } * } * } * ``` @@ -186,6 +186,56 @@ public interface MutableStateFlow : StateFlow, MutableSharedFlow { @Suppress("FunctionName") public fun MutableStateFlow(value: T): MutableStateFlow = StateFlowImpl(value ?: NULL) +// ------------------------------------ Update methods ------------------------------------ + +/** + * Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns the new + * value. + * + * [function] may be evaluated multiple times, if [value] is being concurrently updated. + */ +inline fun MutableStateFlow.updateAndGet(function: (T) -> T): T { + while (true) { + val prevValue = value + val nextValue = function(prevValue) + if (compareAndSet(prevValue, nextValue)) { + return nextValue + } + } +} + +/** + * Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns its + * prior value. + * + * [function] may be evaluated multiple times, if [value] is being concurrently updated. + */ +inline fun MutableStateFlow.updateAndGet(function: (T) -> T): T { + while (true) { + val prevValue = value + val nextValue = function(prevValue) + if (compareAndSet(prevValue, nextValue)) { + return prevValue + } + } +} + + +/** + * Updates the [MutableStateFlow.value] atomically using the specified [function] of its value. + * + * [function] may be evaluated multiple times, if [value] is being concurrently updated. + */ +inline fun MutableStateFlow.update(function: (T) -> T): T { + while (true) { + val prevValue = value + val nextValue = function(prevValue) + if (compareAndSet(prevValue, nextValue)) { + return + } + } +} + // ------------------------------------ Implementation ------------------------------------ @SharedImmutable @@ -366,10 +416,7 @@ private class StateFlowImpl( } internal fun MutableStateFlow.increment(delta: Int) { - while (true) { // CAS loop - val current = value - if (compareAndSet(current, current + delta)) return - } + update { it + delta } } internal fun StateFlow.fuseStateFlow( diff --git a/kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt b/kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt index 0a2c0458c4..363db7a60b 100644 --- a/kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt @@ -193,4 +193,59 @@ class StateFlowTest : TestBase() { assertTrue(state.compareAndSet(d1_1, d0)) // updates, reference changes assertSame(d0, state.value) } + + @Test + fun testGetAndUpdateContended() = runTest { + val state = MutableStateFlow(0) + + // use a barrier to ensure j2 at least doesn't finish before j3 starts + val barrier = Job() + val j2 = async { + barrier.join() + state.getAndUpdate { it + 2 } + } + val j3 = async { + barrier.join() + state.getAndUpdate { it + 3 } + } + barrier.complete() + when (j2.await()) { + 0 -> assertEquals(2, j3.await()) + 3 -> assertEquals(0, j3.await()) + else -> fail() + } + assertEquals(5, state.value) + } + + @Test + fun testUpdateAndGetContended() = runTest { + val state = MutableStateFlow(0) + + // use a barrier to ensure j2 at least doesn't finish before j3 starts + val barrier = Job() + val j2 = async { + barrier.join() + state.updateAndGet { it + 2 } + } + val j3 = async { + barrier.join() + state.updateAndGet { it + 3 } + } + barrier.complete() + when (j2.await()) { + 5 -> assertEquals(3, j3.await()) + 3 -> assertEquals(5, j3.await()) + else -> fail() + } + assertEquals(5, state.value) + } + + @Test + fun update() = runTest { + val state = MutableStateFlow(0) + state.update { it + 2 } + assertEquals(2, state.value) + state.update { it + 3 } + assertEquals(5, state.value) + } } \ No newline at end of file