/
Combine.kt
132 lines (124 loc) · 5.59 KB
/
Combine.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
@file:Suppress("UNCHECKED_CAST", "NON_APPLICABLE_CALL_FOR_BUILDER_INFERENCE") // KT-32203
package kotlinx.coroutines.flow.internal
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
/**
 * Collects every flow in [flows] concurrently and, once each of them has emitted at least once,
 * invokes [transform] with an array holding the latest value of every flow
 * (index `i` of the array corresponds to `flows[i]`).
 *
 * Runs inside [flowScope], so a failure or cancellation in any source flow cancels the whole operation.
 * Combined arrays travel through a conflated channel: if [transform] is slower than the sources,
 * intermediate combinations may be dropped and only the most recent pending one is delivered.
 * Completes once every source flow has completed and the channel has been drained.
 * An empty [flows] array completes immediately without invoking [transform].
 *
 * @param flows source flows to combine.
 * @param arrayFactory called once per combination; must return an array with at least `flows.size` slots,
 *   since every index in `0 until flows.size` is written.
 * @param transform invoked on the downstream collector with each combined array.
 */
@PublishedApi
internal suspend fun <R, T> FlowCollector<R>.combineInternal(
    flows: Array<out Flow<T>>,
    arrayFactory: () -> Array<T?>,
    transform: suspend FlowCollector<R>.(Array<T>) -> Unit
): Unit = flowScope { // flow scope so any cancellation within the source flow will cancel the whole scope
    val size = flows.size
    // With no source flows nobody would ever close resultChannel, and consumeEach below would hang forever.
    if (size == 0) return@flowScope
    // NULL marks "this flow has not emitted yet"; nulls actually emitted by a flow are stored as-is.
    val latestValues = Array<Any?>(size) { NULL }
    val resultChannel = Channel<Array<T>>(Channel.CONFLATED)
    // Number of source flows still being collected; the last one to finish closes resultChannel.
    val nonClosed = LocalAtomicInt(size)
    // Number of source flows that have not emitted their first value yet.
    val remainingAbsentValues = LocalAtomicInt(size)
    for (i in 0 until size) {
        // Coroutine per flow that keeps track of its value and sends result to downstream
        launch {
            try {
                flows[i].collect { value ->
                    val previous = latestValues[i]
                    latestValues[i] = value
                    if (previous === NULL) remainingAbsentValues.decrementAndGet()
                    if (remainingAbsentValues.value == 0) {
                        val results = arrayFactory()
                        for (index in 0 until size) {
                            results[index] = NULL.unbox(latestValues[index])
                        }
                        // NB: here actually "stale" array can overwrite a fresh one and break linearizability
                        resultChannel.send(results as Array<T>)
                    }
                    yield() // Emulate fairness for backward compatibility
                }
            } finally {
                // Close the channel when there are no more flows
                if (nonClosed.decrementAndGet() == 0) {
                    resultChannel.close()
                }
            }
        }
    }
    resultChannel.consumeEach {
        transform(it)
    }
}
/**
 * Implementation of `zip`: pairs each value of [flow] with the next value of [flow2], emits
 * `transform(a, b)` downstream, and stops as soon as either flow is exhausted.
 *
 * [flow2] is collected by a producer coroutine (see [asChannel]). [flow] is collected in the current
 * coroutine under a dedicated `collectJob`. That job is cancelled with an [AbortFlowException] when the
 * channel of [flow2] is closed, so [flow] does not keep running (e.g. suspended in `delay`) after [flow2] is done.
 * Downstream emissions run in the outer scope's job, so cancelling `collectJob` never cancels the downstream.
 */
internal fun <T1, T2, R> zipImpl(flow: Flow<T1>, flow2: Flow<T2>, transform: suspend (T1, T2) -> R): Flow<R> =
    unsafeFlow {
        coroutineScope {
            val second = asChannel(flow2)
            /*
             * This approach only works with rendezvous channel and is required to enforce correctness
             * in the following scenario:
             * ```
             * val f1 = flow { emit(1); delay(Long.MAX_VALUE) }
             * val f2 = flowOf(1)
             * f1.zip(f2) { ... }
             * ```
             *
             * Invariant: this clause is invoked only when all elements from the channel were processed (=> rendezvous restriction).
             */
            val collectJob = Job()
            // coroutineScope always installs a Job in its context, so this lookup cannot yield null.
            val scopeJob = currentCoroutineContext()[Job]!!
            (second as SendChannel<*>).invokeOnClose {
                // Abort the collection of the first flow once the second one is done.
                // The isActive check avoids allocating an AFE when collectJob has already been cancelled.
                if (collectJob.isActive) collectJob.cancel(AbortFlowException(this@unsafeFlow))
            }
            val newContext = coroutineContext + scopeJob
            // Computed once and reused for every emission below.
            val cnt = threadContextElements(newContext)
            try {
                /*
                 * Non-trivial undispatched (because we are in the right context and there is no structured concurrency)
                 * hierarchy:
                 * - Outer coroutineScope that owns the whole zip process
                 * - First flow is collected by the child of coroutineScope, collectJob.
                 *   So it can be safely cancelled as soon as the second flow is done
                 * - **But** the downstream MUST NOT be cancelled when the second flow is done,
                 *   so we emit to downstream from coroutineScope job.
                 * Typically, such hierarchy requires coroutine for collector that communicates
                 * with coroutines scope via a channel, but it's way too expensive, so
                 * we are using this trick instead.
                 */
                withContextUndispatched(coroutineContext + collectJob) {
                    flow.collect { value ->
                        // null here means the channel is closed (values are boxed as NULL by asChannel)
                        val otherValue = second.receiveOrNull() ?: return@collect
                        withContextUndispatched(newContext, cnt) {
                            emit(transform(NULL.unbox(value), NULL.unbox(otherValue)))
                        }
                        ensureActive()
                    }
                }
            } catch (e: AbortFlowException) {
                e.checkOwnership(owner = this@unsafeFlow)
            } finally {
                if (!second.isClosedForReceive) second.cancel(AbortFlowException(this@unsafeFlow))
            }
        }
    }
/**
 * Starts [block] with [newContext] directly on the caller's thread, without going through a dispatcher.
 *
 * The thread context of [newContext] (described by [countOrElement]) is applied around the start of [block].
 * [block] completes into a continuation whose context is [newContext] and which forwards the outcome to the
 * caller's continuation. If [block] finishes without suspending, its result is returned immediately;
 * otherwise the caller stays suspended until that completion resumes it.
 *
 * @param countOrElement thread-context elements of [newContext]; callers on a hot path may precompute it
 *   once and pass it in (as `zipImpl` does) instead of relying on the default.
 */
private suspend fun withContextUndispatched(
    newContext: CoroutineContext,
    countOrElement: Any = threadContextElements(newContext),
    block: suspend () -> Unit
): Unit = suspendCoroutineUninterceptedOrReturn { callerCont ->
    // Completion of `block`: carries newContext and hands the outcome back to the suspended caller.
    val completion = Continuation<Unit>(newContext) { outcome -> callerCont.resumeWith(outcome) }
    withCoroutineContext(newContext, countOrElement) {
        // Either COROUTINE_SUSPENDED or the result of a synchronously completed block.
        block.startCoroutineUninterceptedOrReturn(completion)
    }
}
/**
 * Launches a producer coroutine in this scope that collects [flow] and sends each value into the returned channel.
 *
 * The element type is the non-null `Any` so that the consumer can use `receiveOrNull()`, where `null` means
 * "channel closed". A `null` emitted by [flow] is therefore sent as the [NULL] sentinel, and the receiver
 * is expected to unbox it.
 * TODO: give the channel a precise element type once `receiveOrClosed` is available.
 */
private fun CoroutineScope.asChannel(flow: Flow<*>): ReceiveChannel<Any> = produce {
    flow.collect { element ->
        val boxed: Any = element ?: NULL
        channel.send(boxed)
    }
}