From 8133c973bfa339d89c8df09e2889c9f884fa2b04 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Thu, 31 Mar 2022 03:48:19 -0700 Subject: [PATCH] Fix limitedParallelism implementation on K/N (#3226) The initial implementation predates new memory model and was never working on it Fixes #3223 --- .../common/src/internal/LimitedDispatcher.kt | 9 +-- .../test/LimitedParallelismSharedTest.kt | 34 +++++++++++ .../test/LimitedParallelismConcurrentTest.kt | 59 +++++++++++++++++++ ...mitedParallelismUnhandledExceptionTest.kt} | 25 +------- 4 files changed, 99 insertions(+), 28 deletions(-) create mode 100644 kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt create mode 100644 kotlinx-coroutines-core/concurrent/test/LimitedParallelismConcurrentTest.kt rename kotlinx-coroutines-core/jvm/test/{LimitedParallelismTest.kt => LimitedParallelismUnhandledExceptionTest.kt} (51%) diff --git a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt index 892375b89f..28f37ecf1d 100644 --- a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt @@ -23,6 +23,9 @@ internal class LimitedDispatcher( private val queue = LockFreeTaskQueue(singleConsumer = false) + // A separate object that we can synchronize on for K/N + private val workerAllocationLock = SynchronizedObject() + @ExperimentalCoroutinesApi override fun limitedParallelism(parallelism: Int): CoroutineDispatcher { parallelism.checkParallelism() @@ -50,8 +53,7 @@ internal class LimitedDispatcher( continue } - @Suppress("CAST_NEVER_SUCCEEDS") - synchronized(this as SynchronizedObject) { + synchronized(workerAllocationLock) { --runningWorkers if (queue.size == 0) return ++runningWorkers @@ -87,8 +89,7 @@ internal class LimitedDispatcher( } private fun tryAllocateWorker(): Boolean { - @Suppress("CAST_NEVER_SUCCEEDS") - synchronized(this as SynchronizedObject) { + synchronized(workerAllocationLock) { if (runningWorkers >= parallelism) return false ++runningWorkers return true diff --git a/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt b/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt new file mode 100644 index 0000000000..d01e85716b --- /dev/null +++ b/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlin.test.* + +class LimitedParallelismSharedTest : TestBase() { + + @Test + fun testLimitedDefault() = runTest { + // Test that evaluates the very basic completion of tasks in limited dispatcher + // for all supported platforms. + // For more specific and concurrent tests, see 'concurrent' package. + val view = Dispatchers.Default.limitedParallelism(1) + val view2 = Dispatchers.Default.limitedParallelism(1) + val j1 = launch(view) { + while (true) { + yield() + } + } + val j2 = launch(view2) { j1.cancel() } + joinAll(j1, j2) + } + + @Test + fun testParallelismSpec() { + assertFailsWith { Dispatchers.Default.limitedParallelism(0) } + assertFailsWith { Dispatchers.Default.limitedParallelism(-1) } + assertFailsWith { Dispatchers.Default.limitedParallelism(Int.MIN_VALUE) } + Dispatchers.Default.limitedParallelism(Int.MAX_VALUE) + } +} diff --git a/kotlinx-coroutines-core/concurrent/test/LimitedParallelismConcurrentTest.kt b/kotlinx-coroutines-core/concurrent/test/LimitedParallelismConcurrentTest.kt new file mode 100644 index 0000000000..964f678e74 --- /dev/null +++ b/kotlinx-coroutines-core/concurrent/test/LimitedParallelismConcurrentTest.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.exceptions.* +import kotlin.test.* + +class LimitedParallelismConcurrentTest : TestBase() { + + private val targetParallelism = 4 + private val iterations = 100_000 + private val parallelism = atomic(0) + + private fun checkParallelism() { + val value = parallelism.incrementAndGet() + randomWait() + assertTrue { value <= targetParallelism } + parallelism.decrementAndGet() + } + + @Test + fun testLimitedExecutor() = runMtTest { + val executor = newFixedThreadPoolContext(targetParallelism, "test") + val view = executor.limitedParallelism(targetParallelism) + doStress { + repeat(iterations) { + launch(view) { + checkParallelism() + } + } + } + executor.close() + } + + private suspend inline fun doStress(crossinline block: suspend CoroutineScope.() -> Unit) { + repeat(stressTestMultiplier) { + coroutineScope { + block() + } + } + } + + @Test + fun testTaskFairness() = runMtTest { + val executor = newSingleThreadContext("test") + val view = executor.limitedParallelism(1) + val view2 = executor.limitedParallelism(1) + val j1 = launch(view) { + while (true) { + yield() + } + } + val j2 = launch(view2) { j1.cancel() } + joinAll(j1, j2) + executor.close() + } +} diff --git a/kotlinx-coroutines-core/jvm/test/LimitedParallelismTest.kt b/kotlinx-coroutines-core/jvm/test/LimitedParallelismUnhandledExceptionTest.kt similarity index 51% rename from kotlinx-coroutines-core/jvm/test/LimitedParallelismTest.kt rename to kotlinx-coroutines-core/jvm/test/LimitedParallelismUnhandledExceptionTest.kt index 30c54117a9..8d48aa43b3 100644 --- a/kotlinx-coroutines-core/jvm/test/LimitedParallelismTest.kt +++ b/kotlinx-coroutines-core/jvm/test/LimitedParallelismUnhandledExceptionTest.kt @@ -9,30 +9,7 @@ import java.util.concurrent.* import kotlin.coroutines.* import kotlin.test.* -class LimitedParallelismTest : TestBase() { - - @Test - fun testParallelismSpec() { - assertFailsWith { Dispatchers.Default.limitedParallelism(0) } - assertFailsWith { Dispatchers.Default.limitedParallelism(-1) } - assertFailsWith { Dispatchers.Default.limitedParallelism(Int.MIN_VALUE) } - Dispatchers.Default.limitedParallelism(Int.MAX_VALUE) - } - - @Test - fun testTaskFairness() = runTest { - val executor = newSingleThreadContext("test") - val view = executor.limitedParallelism(1) - val view2 = executor.limitedParallelism(1) - val j1 = launch(view) { - while (true) { - yield() - } - } - val j2 = launch(view2) { j1.cancel() } - joinAll(j1, j2) - executor.close() - } +class LimitedParallelismUnhandledExceptionTest : TestBase() { @Test fun testUnhandledException() = runTest {