Skip to content

Commit

Permalink
LimitedDispatcher fixes
Browse files Browse the repository at this point in the history
    * Support dispatchYield
    * Fix doc
    * Short-circuit limitedParallelism(x).limitedParallelism(y) for y >= x
  • Loading branch information
qwwdfsad committed Sep 20, 2021
1 parent 00122c5 commit ecd36dd
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 16 deletions.
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt
Expand Up @@ -87,7 +87,7 @@ public abstract class CoroutineDispatcher :
* private val fileWriterDispatcher = backgroundDispatcher.limitedParallelism(1)
* ```
* Note how in this example, the application have the executor with 4 threads, but the total sum of all limits
* is 5. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism.
* is 6. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism.
*/
@ExperimentalCoroutinesApi
public open fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
Expand Down
52 changes: 38 additions & 14 deletions kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt
Expand Up @@ -23,9 +23,11 @@ internal class LimitedDispatcher(

private val queue = LockFreeTaskQueue<Runnable>(singleConsumer = false)

@InternalCoroutinesApi
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
dispatcher.dispatchYield(context, block)
@ExperimentalCoroutinesApi
override fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
parallelism.checkParallelism()
if (parallelism >= this.parallelism) return this
return super.limitedParallelism(parallelism)
}

override fun run() {
Expand Down Expand Up @@ -59,25 +61,47 @@ internal class LimitedDispatcher(
}

override fun dispatch(context: CoroutineContext, block: Runnable) {
// Add task to queue so running workers will be able to see that
queue.addLast(block)
if (runningWorkers >= parallelism) {
return
dispatchInternal(block) {
if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
dispatcher.dispatch(EmptyCoroutineContext, this)
} else {
run()
}
}
}

@InternalCoroutinesApi
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
dispatchInternal(block) {
dispatcher.dispatchYield(context, this)
}
}

private inline fun dispatchInternal(block: Runnable, dispatch: () -> Unit) {
// Add task to queue so running workers will be able to see that
if (tryAdd(block)) return
/*
* Protect against race when the worker is finished right after our check.
* Protect against the race when the number of workers is enough,
* but one (because of synchronized serialization) attempts to complete,
* and we just observed the number of running workers smaller than the actual
* number (hit right between `--runningWorkers` and `++runningWorkers` in `run()`)
*/
if (enoughWorkers()) return
dispatch()
}

private fun enoughWorkers(): Boolean {
@Suppress("CAST_NEVER_SUCCEEDS")
synchronized(this as SynchronizedObject) {
if (runningWorkers >= parallelism) return
if (runningWorkers >= parallelism) return true
++runningWorkers
return false
}
if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
dispatcher.dispatch(EmptyCoroutineContext, this)
} else {
run()
}
}

private fun tryAdd(block: Runnable): Boolean {
queue.addLast(block)
return runningWorkers >= parallelism
}
}

Expand Down
23 changes: 22 additions & 1 deletion kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt
Expand Up @@ -33,7 +33,7 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas
}

@Test
fun testLimited() = runTest {
fun testLimitedExecutor() = runTest {
val view = executor.limitedParallelism(targetParallelism)
repeat(iterations) {
launch(view) {
Expand All @@ -42,6 +42,27 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas
}
}

@Test
fun testLimitedDispatchersIo() = runTest {
val view = Dispatchers.IO.limitedParallelism(targetParallelism)
repeat(iterations) {
launch(view) {
checkParallelism()
}
}
}

@Test
fun testLimitedDispatchersIoDispatchYield() = runTest {
val view = Dispatchers.IO.limitedParallelism(targetParallelism)
repeat(iterations) {
launch(view) {
yield()
checkParallelism()
}
}
}

@Test
fun testUnconfined() = runTest {
val view = Dispatchers.Unconfined.limitedParallelism(targetParallelism)
Expand Down

0 comments on commit ecd36dd

Please sign in to comment.