diff --git a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt index e74e9bedea..fe8b3ec4d6 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt @@ -87,7 +87,7 @@ public abstract class CoroutineDispatcher : * private val fileWriterDispatcher = backgroundDispatcher.limitedParallelism(1) * ``` * Note how in this example, the application have the executor with 4 threads, but the total sum of all limits - * is 5. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism. + * is 6. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism. */ @ExperimentalCoroutinesApi public open fun limitedParallelism(parallelism: Int): CoroutineDispatcher { diff --git a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt index 5cbd5b8eef..acb65c53b5 100644 --- a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt @@ -23,9 +23,11 @@ internal class LimitedDispatcher( private val queue = LockFreeTaskQueue(singleConsumer = false) - @InternalCoroutinesApi - override fun dispatchYield(context: CoroutineContext, block: Runnable) { - dispatcher.dispatchYield(context, block) + @ExperimentalCoroutinesApi + override fun limitedParallelism(parallelism: Int): CoroutineDispatcher { + parallelism.checkParallelism() + if (parallelism >= this.parallelism) return this + return super.limitedParallelism(parallelism) } override fun run() { @@ -59,25 +61,47 @@ internal class LimitedDispatcher( } override fun dispatch(context: CoroutineContext, block: Runnable) { - // Add task to queue so running workers will be able to see that - queue.addLast(block) - if (runningWorkers >= parallelism) { - return + dispatchInternal(block) { + if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) { + dispatcher.dispatch(EmptyCoroutineContext, this) + } else { + run() + } + } + } + + @InternalCoroutinesApi + override fun dispatchYield(context: CoroutineContext, block: Runnable) { + dispatchInternal(block) { + dispatcher.dispatchYield(context, this) } + } + private inline fun dispatchInternal(block: Runnable, dispatch: () -> Unit) { + // Add task to queue so running workers will be able to see that + if (tryAdd(block)) return /* - * Protect against race when the worker is finished right after our check. + * Protect against the race when the number of workers is enough, + * but one (because of synchronized serialization) attempts to complete, + * and we just observed the number of running workers smaller than the actual + * number (hit right between `--runningWorkers` and `++runningWorkers` in `run()`) */ + if (enoughWorkers()) return + dispatch() + } + + private fun enoughWorkers(): Boolean { @Suppress("CAST_NEVER_SUCCEEDS") synchronized(this as SynchronizedObject) { - if (runningWorkers >= parallelism) return + if (runningWorkers >= parallelism) return true ++runningWorkers + return false } - if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) { - dispatcher.dispatch(EmptyCoroutineContext, this) - } else { - run() - } + } + + private fun tryAdd(block: Runnable): Boolean { + queue.addLast(block) + return runningWorkers >= parallelism } } diff --git a/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt b/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt index eb1c693639..d994e85330 100644 --- a/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt @@ -33,7 +33,7 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas } @Test - fun testLimited() = runTest { + fun testLimitedExecutor() = runTest { val view = executor.limitedParallelism(targetParallelism) repeat(iterations) { launch(view) { @@ -42,6 +42,27 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas } } + @Test + fun testLimitedDispatchersIo() = runTest { + val view = Dispatchers.IO.limitedParallelism(targetParallelism) + repeat(iterations) { + launch(view) { + checkParallelism() + } + } + } + + @Test + fun testLimitedDispatchersIoDispatchYield() = runTest { + val view = Dispatchers.IO.limitedParallelism(targetParallelism) + repeat(iterations) { + launch(view) { + yield() + checkParallelism() + } + } + } + @Test fun testUnconfined() = runTest { val view = Dispatchers.Unconfined.limitedParallelism(targetParallelism)